Проблема с Keras batch_size при использовании model.fit в пользовательских слоях - PullRequest
0 голосов
/ 23 февраля 2019

У меня есть пользовательский слой, который изменяет входной тензор, выполняет некоторые точечные произведения с ядрами и возвращает тензор с тем же числом измерений.Вход в мою сеть - изображения, скажем, размером 61x80.Когда число изображений поезда кратно batch_size, model.fit работает нормально.например, общее количество изображений поездов = 2700, batch_size = 10.Но когда общее количество изображений поездов = 2701, batch_size () не работает, выдает ошибку примерно так:

Epoch 1/5 2520/2701 [==========================> ...] - ETA: 0s - потеря: 2.7465 - acc: 0.2516Traceback (последний последний вызов):

Файл "", строка 5, в истории = model.fit (x_train, y_train, batch_size = 10, epochs = 5)

Файл "/home/eee/anaconda3/lib/python3.6/site-packages / keras / engine / training.py ", строка 1039, в форме fitation_steps = validation_steps)

Файл" /home/eee/anaconda3/lib/python3.6/site-packages/keras/engine/training_arrays.py ", строка 199, в fit_loop outs = f (ins_batch)

Файл" /home/eee/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py ", строка 2715, в вызов , возврат self._call (входы)

Файл "/home/eee/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py ", строка 2675, в _call fetched = self._callable_fn (* array_vals)

File" / home / eee / anaconda3 / lib / python3.6 / site-packages / tenorflow / python / client / session.py ", строка 1439, в call run_metadata_ptr)

File" / home / eee/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py ", строка 528, в выход c_api.TF_GetCode (self.status.status))

InvalidArgumentError: Вход для изменения формы является тензором со значениями 4880, но запрошенная форма имеет 48800 [[{{node my_layer_3 / Reshape}} = Reshape [T = DT_FLOAT, Tshape = DT_INT32, _device = "/ job: localhost / replica: 0 / задача: 0 / устройство: ЦП: 0 "] (_arg_input_2_0_0, my_layer_3 / stack)]]

, пожалуйста, помогите, как обойти эту проблему.

Редактировать: -Добавление кода пользовательских слоев

class MyLayer(Layer):

def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)

def build(self, input_shape):
    print(len(input_shape))
    # Create a trainable weight variable for this layer.
    assert len(input_shape) >= 3
    input_dim = input_shape[1:]
    print(input_shape)


    self.kernel1 = self.add_weight(shape=self.output_dim[0],input_dim[0]),
                                   name = 'kernel1',
                                  initializer='uniform',
                                  trainable=True)
    print(self.kernel1)

    self.kernel2 = self.add_weight(shape=self.output_dim[1],input_dim[1]),
                                   name = 'kernel2',
                                  initializer='uniform',
                                  trainable=True)
    print(self.kernel2)


    super(MyLayer, self).build(input_shape)  

def call(self, x):
     print(x.shape)
     input_shape=x.shape

     mat1_shape =K.int_shape(self.kernel1)     
     mat2_shape =K.int_shape(self.kernel2)

     output1 = Myoperation(x,self.kernel1,1)
     output2 = Myoperation(output1,self.kernel2,2)                

     return output2


def compute_output_shape(self, input_shape):

    return (input_shape[0],self.output_dim[0],self.output_dim[1])

Код для функции Myoperation -

def Myoperation(x,mat,mode):
    shape1 = K.shape(x) 
    mode_list =[0,1,2]
    mode_list.remove(mode)
    mode_shape= shape1[mode]

    new_shape = tf.stack( [mode_shape,(shape1[mode_list[0]]*shape1[mode_list[1]])])                    
    input_reshaped = K.reshape(x, new_shape)
    ten_mul = K.dot(mat,input_reshaped)
    out_mode=K.int_shape(mat)
    if (mode==1):
       out_shape =  tf.stack([shape1[mode_list[0]],out_mode[0],shape1[mode_list[1]]])
    if (mode==2):
       out_shape = tf.stack([shape1[mode_list[0]],shape1[mode_list[1]],out_mode[0]])

    output_reshaped = K.reshape(ten_mul,out_shape)    

   return output_reshaped

Проблема заключается в изменении формы тензора, когда набор изображений поезда не кратенразмер партии.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...