Concatenation makes Keras training fail
0 votes
/ 16 October 2019

I want to implement a U-net CNN architecture. In this architecture, the "upsampling" part contains several concatenations. I am using Keras 2.1.3, Python 2.7, and TensorFlow '1.4.0-rc0'.

My inputs have shape (6, 128, 128) (channels first). Here is the code I came up with:

from keras.models import Model
from keras.layers import Input, Conv2D, Activation, MaxPooling2D, UpSampling2D, Concatenate

input_shape = (6, 128, 128)

# Build the U-net model
input_fields = Input(shape=input_shape)

f32 = Conv2D(32, (3,3), padding="same")(input_fields)
f32 = Activation("relu", name="f32")(f32)

s32 = Conv2D(32, (3,3), padding="same")(f32)
s32 = Activation("relu",name="s32")(s32) ## To concatenate 32

pool32_64 = MaxPooling2D((2,2), padding="same")(s32)

f64 = Conv2D(64, (3,3), padding="same")(pool32_64)
f64 = Activation("relu")(f64)

s64 = Conv2D(64, (3,3), padding="same")(f64)
s64 = Activation("relu")(s64) # To concatenate 64

pool64_128 = MaxPooling2D((2,2), padding="same")(s64)

f128 = Conv2D(128, (3,3), padding="same")(pool64_128)
f128 = Activation("relu")(f128)

s128 = Conv2D(128, (3,3), padding="same")(f128)
s128 = Activation("relu")(s128)
print "Last shape before Upsampling "s128.get_shape()

#### vvv Upsampling Part vvv  ####

up_128_64 = UpSampling2D((2,2))(s128)
up_128_64 = Conv2D(64, (2,2), padding="same")(up_128_64)
print "Conv2d pu_128_64 ", up_128_64.get_shape()

m64 = Concatenate(axis=0)([s64, up_128_64]) #or concatenate([s64, up_128_64], axis=0)

f64U = Conv2D(64, (3,3), padding="same")(m64)
f64U = Activation("relu")(f64U)
#print "f64U.get_shape()", f64U.get_shape()


s64U = Conv2D(64, (3,3), padding="same")(f64U)
s64U = Activation("relu")(s64U)

up_64_32 = UpSampling2D((2,2))(s64U)
up_64_32 = Conv2D(32, (2,2), padding="same")(up_64_32)

m32 = Concatenate(axis=0)([s32, up_64_32]) # or concatenate([s32, up_64_32], axis=0)

f32U = Conv2D(32, (3,3), padding="same")(m32)
f32U = Activation("relu")(f32U)
print "f32U.get_shape()", f32U.get_shape()

s32U = Conv2D(32, (3,3), padding="same")(f32U)
s32U = Activation("relu")(s32U)


output_field = Conv2D(1, (1,1), padding="same")(s32U)
output_field = Activation("relu")(output_field)

print output_field.get_shape()

U_net = Model(input_fields, output_field)
U_net.summary()

U_net.compile(optimizer="RMSProp", loss="mse")#, metrics=["accuracy"])

U_net.fit(X_train, y_train)

The output of U_net.summary() is:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 6, 128, 128)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 128, 128) 1760        input_1[0][0]                    
__________________________________________________________________________________________________
f32 (Activation)                (None, 32, 128, 128) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 32, 128, 128) 9248        f32[0][0]                        
__________________________________________________________________________________________________
s32 (Activation)                (None, 32, 128, 128) 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 32, 64, 64)   0           s32[0][0]                        
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 64, 64, 64)   18496       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 64, 64, 64)   0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 64, 64, 64)   36928       activation_1[0][0]               
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 64, 64, 64)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 64, 32, 32)   0           activation_2[0][0]               
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 128, 32, 32)  73856       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 128, 32, 32)  0           conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 128, 32, 32)  147584      activation_3[0][0]               
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 128, 32, 32)  0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 128, 64, 64)  0           activation_4[0][0]               
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 64, 64, 64)   32832       up_sampling2d_1[0][0]            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 64, 64, 64)   0           activation_2[0][0]               
                                                                 conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 64, 64, 64)   36928       concatenate_1[0][0]              
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 64, 64, 64)   0           conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 64, 64, 64)   36928       activation_5[0][0]               
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 64, 64, 64)   0           conv2d_9[0][0]                   
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 64, 128, 128) 0           activation_6[0][0]               
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 128, 128) 8224        up_sampling2d_2[0][0]            
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 128, 128) 0           s32[0][0]                        
                                                                 conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 128, 128) 9248        concatenate_2[0][0]              
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 32, 128, 128) 0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 128, 128) 9248        activation_7[0][0]               
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 32, 128, 128) 0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 1, 128, 128)  33          activation_8[0][0]               
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 1, 128, 128)  0           conv2d_13[0][0]                  
==================================================================================================
Total params: 421,313
Trainable params: 421,313
Non-trainable params: 0

So the network builds, and

X_train.shape = (576, 6, 128, 128)
y_train.shape = (576, 1, 128, 128)

But during training I get this error:

Epoch 1/1
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
........../run_CNN_tau.py in <module>()
    174 U_net.compile(optimizer="RMSProp", loss="mae")#, metrics=["accuracy"])
    175 
--> 176 U_net.fit(X_train, y_train)#, validation_data=(X_test, y_test), epochs=1)
    177 
    178 #model.add(Conv2D(32, kernel_size=(16,16), padding="same", activation='relu', input_shape=input_shape))

/usr/local/lib/python2.7/dist-packages/keras/engine/training.pyc in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1667                               initial_epoch=initial_epoch,
   1668                               steps_per_epoch=steps_per_epoch,
-> 1669                               validation_steps=validation_steps)
   1670 
   1671     def evaluate(self, x=None, y=None,

/usr/local/lib/python2.7/dist-packages/keras/engine/training.pyc in _fit_loop(self, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
   1204                         ins_batch[i] = ins_batch[i].toarray()
   1205 
-> 1206                     outs = f(ins_batch)
   1207                     if not isinstance(outs, list):
   1208                         outs = [outs]

/usr/local/lib/python2.7/dist-packages/keras/backend/tensorflow_backend.pyc in __call__(self, inputs)
   2473         session = get_session()
   2474         updated = session.run(fetches=fetches, feed_dict=feed_dict,
-> 2475                               **self.session_kwargs)
   2476         return updated[:len(self.outputs)]
   2477 

/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    887     try:
    888       result = self._run(None, fetches, feed_dict, options_ptr,
--> 889                          run_metadata_ptr)
    890       if run_metadata:
    891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1118     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1119       results = self._do_run(handle, final_targets, final_fetches,
-> 1120                              feed_dict_tensor, options, run_metadata)
   1121     else:
   1122       results = []

/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1315     if handle is None:
   1316       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1317                            options, run_metadata)
   1318     else:
   1319       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
   1334         except KeyError:
   1335           pass
-> 1336       raise type(e)(node_def, op, message)
   1337 
   1338   def _extend_graph(self):

InvalidArgumentError: Incompatible shapes: [96,1,128,128] vs. [32,1,128,128]
     [[Node: training/RMSprop/gradients/loss/activation_9_loss/sub_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@loss/activation_9_loss/sub"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training/RMSprop/gradients/loss/activation_9_loss/sub_grad/Shape, training/RMSprop/gradients/loss/activation_9_loss/sub_grad/Shape_1)]]

Caused by op u'training/RMSprop/gradients/loss/activation_9_loss/sub_grad/BroadcastGradientArgs', defined at:
  File "/usr/local/bin/ipython", line 11, in <module>
    sys.exit(start_ipython())
  File "/usr/local/lib/python2.7/dist-packages/IPython/__init__.py", line 119, in start_ipython
    return launch_new_instance(argv=argv, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/usr/local/lib/python2.7/dist-packages/IPython/terminal/ipapp.py", line 355, in start
    self.shell.mainloop()
  File "/usr/local/lib/python2.7/dist-packages/IPython/terminal/interactiveshell.py", line 495, in mainloop
    self.interact()
  File "/usr/local/lib/python2.7/dist-packages/IPython/terminal/interactiveshell.py", line 486, in interact
    self.run_cell(code, store_history=True)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2714, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2824, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2878, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-463-b869a174fafa>", line 1, in <module>
    get_ipython().magic(u'run run_CNN_tau.py')
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2160, in magic
    return self.run_line_magic(magic_name, magic_arg_s)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2081, in run_line_magic
    result = fn(*args,**kwargs)
  File "<decorator-gen-58>", line 2, in run
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/magic.py", line 188, in <lambda>
    call = lambda f, *a, **k: f(*a, **k)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/magics/execution.py", line 742, in run
    run()
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/magics/execution.py", line 728, in run
    exit_ignore=exit_ignore)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2483, in safe_execfile
    self.compile if kw['shell_futures'] else None)
  File "/usr/local/lib/python2.7/dist-packages/IPython/utils/py3compat.py", line 289, in execfile
    builtin_mod.execfile(filename, *where)
  File "/home/nsaura/Documents/Git_RANNS/ML/turbo/wk/tests/python/run_CNN_tau.py", line 176, in <module>
    U_net.fit(X_train, y_train)#, validation_data=(X_test, y_test), epochs=1)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1646, in fit
    self._make_train_function()
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 970, in _make_train_function
    loss=self.total_loss)
  File "/usr/local/lib/python2.7/dist-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/keras/optimizers.py", line 233, in get_updates
    grads = self.get_gradients(loss, params)
  File "/usr/local/lib/python2.7/dist-packages/keras/optimizers.py", line 78, in get_gradients
    grads = K.gradients(loss, params)
  File "/usr/local/lib/python2.7/dist-packages/keras/backend/tensorflow_backend.py", line 2512, in gradients
    return tf.gradients(loss, variables, colocate_gradients_with_ops=True)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gradients_impl.py", line 581, in gradients
    grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gradients_impl.py", line 353, in _MaybeCompile
    return grad_fn()  # Exit early
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gradients_impl.py", line 581, in <lambda>
    grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/math_grad.py", line 727, in _SubGrad
    rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 532, in _broadcast_gradient_args
    "BroadcastGradientArgs", s0=s0, s1=s1, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

...which was originally created as op u'loss/activation_9_loss/sub', defined at:
  File "/usr/local/bin/ipython", line 11, in <module>
    sys.exit(start_ipython())
[elided 16 identical lines from previous traceback]
  File "/usr/local/lib/python2.7/dist-packages/IPython/utils/py3compat.py", line 289, in execfile
    builtin_mod.execfile(filename, *where)
  File "/home/nsaura/Documents/Git_RANNS/ML/turbo/wk/tests/python/run_CNN_tau.py", line 174, in <module>
    U_net.compile(optimizer="RMSProp", loss="mae")#, metrics=["accuracy"])
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 827, in compile
    sample_weight, mask)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 426, in weighted
    score_array = fn(y_true, y_pred)
  File "/usr/local/lib/python2.7/dist-packages/keras/losses.py", line 18, in mean_absolute_error
    return K.mean(K.abs(y_pred - y_true), axis=-1)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/math_ops.py", line 894, in binary_op_wrapper
    return func(x, y, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_math_ops.py", line 4636, in _sub
    "Sub", x=x, y=y, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Incompatible shapes: [96,1,128,128] vs. [32,1,128,128]
     [[Node: training/RMSprop/gradients/loss/activation_9_loss/sub_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@loss/activation_9_loss/sub"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training/RMSprop/gradients/loss/activation_9_loss/sub_grad/Shape, training/RMSprop/gradients/loss/activation_9_loss/sub_grad/Shape_1)]]

The thing is, this network works fine if the concatenation layers are removed. Can someone explain how I can fix this problem?

1 Answer

1 vote
/ 17 October 2019

According to the Keras documentation, the default batch size for training is 32 samples (https://keras.io/models/model/#fit), and looking at your architecture it seems that you are essentially taking the input stream, splitting it, and then merging it back in twice (once per concatenation), ending up with 96 samples per batch. That would explain the content of the error message: "[96,1,128,128] vs. [32,1,128,128]".
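As a rough illustration with placeholder numpy arrays (the variable names mirror yours, but these are stand-ins, not your real tensors), concatenating along axis 0 grows the batch dimension at each skip connection, which is how a batch of 32 ends up as 96 at the output:

import numpy as np

batch = 32  # Keras' default batch size in fit()

# Stand-ins for the first skip connection (channels_first layout).
s64       = np.zeros((batch, 64, 64, 64))          # encoder feature map
up_128_64 = np.zeros((batch, 64, 64, 64))          # upsampled decoder feature map
m64 = np.concatenate([s64, up_128_64], axis=0)     # batch axis: (64, 64, 64, 64)

# Second skip connection: the doubled batch is carried forward by the decoder.
s32      = np.zeros((batch, 32, 128, 128))
up_64_32 = np.zeros((m64.shape[0], 32, 128, 128))
m32 = np.concatenate([s32, up_64_32], axis=0)      # (96, 32, 128, 128)

print(m32.shape)  # 32 + 64 = 96 samples, while y_train still supplies 32 per batch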

Are you sure you want to be concatenating along the batch dimension? Hope this helps.
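If what you actually want is the usual U-net skip connection, the feature maps should be stacked along the channel axis instead, which is axis 1 with your channels_first setup. Here is a minimal, self-contained sketch of a single skip connection (simplified, not your full network), assuming the channels_first data format your summary shows:

from keras import backend as K
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate

K.set_image_data_format("channels_first")  # matches the (6, 128, 128) inputs

inp  = Input(shape=(6, 128, 128))
s32  = Conv2D(32, (3, 3), padding="same", activation="relu")(inp)
down = MaxPooling2D((2, 2), padding="same")(s32)
up   = Conv2D(32, (2, 2), padding="same")(UpSampling2D((2, 2))(down))

# axis=1 is the channel axis for channels_first data, so the batch size is
# preserved and only the feature maps are stacked: (None, 64, 128, 128).
merged = Concatenate(axis=1)([s32, up])
out    = Conv2D(1, (1, 1), padding="same")(merged)

Model(inp, out).summary()

In your code that would just mean replacing Concatenate(axis=0) with Concatenate(axis=1) for both m64 and m32.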
