ошибка в форме данных с керами слоя conv2D - PullRequest
0 голосов
/ 20 ноября 2018

Я крупный владелец нейронных сетей, я получаю следующую ошибку при запуске fit,

ValueError: Ошибка при проверке цели: ожидалось, что conv2d_1 будет иметь форму (64, 222, 222), но получилмассив с формой (1, 224, 224)

Я использую изображения в градациях серого, насколько я знаю, я думаю, что я правильно формирую входные данные.я не могу понять, что я делаю неправильно.

вот фрагмент сети

модель сети:

def convLayer(channels):
  return

Conv2D(channels,kernel_size=3,activation='relu',\
 kernel_initializer=initializers.random_normal(mean=0.0, stddev=0.01),\ 
  data_format='channels_first')

class est_net():

  def __init__(self, input=None):
    if input is None:
        input=Input(shape=(1,224,224))
    self.input=input

    conv1_1 = convLayer(64)(self.input)

    self.output = conv1_1
    self.CDECNN = Model(inputs=self.input, outputs=self.output)
    print(self.CDECNN.summary())

чтение данных:

def __iter__(self):
    files=self.img_files
    for f in files:
        if f==".DS_Store":
            continue
        img=cv2.imread(os.path.join(self.img_path,f),cv2.COLOR_BGR2GRAY)
        img=img.reshape(1, img.shape[0], img.shape[1])
        if img is None:
            print("unable to read image %s." % f)
            exit(0)
        gt_file='GT_'+f.split('.')[0]+'.mat'
        gt=sio.loadmat(os.path.join(self.gt_path,gt_file))['d_map']
        gt=gt.reshape(1, gt.shape[0], gt.shape[1])
        yield(img,gt)

обучение:

input_img=[] #representing input segment images to be fed to the network
actual_dgt=[] #representing the actual dot-map ground-truth

for i, (img, dgt) in enumerate(training_DS):
    input_img.append(img)
    actual_dgt.append(dgt)

#initializing training parameters
sgd=optimizers.SGD(lr=0.01, decay=0.0005, momentum=0.9)

#compiling the network and defining the loss method
net.CDECNN.compile(optimizer=sgd, loss='mean_squared_error')

#training CDECNN network on training data
training_log=net.CDECNN.fit(np.array(input_img), np.array(actual_dgt), batch_size=1, epochs=5)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...