Ошибка получения несовместимых фигур с batch_size> 1 с керасом - PullRequest
0 голосов
/ 19 сентября 2018

Я недавно начал работать с keras и посмотрел на все доступные решения, но они не работали.У меня есть простая нейронная сеть для вычисления значения тензора из 5 входных констант и 10 входных тензоров. Это моя нейронная сеть, входные данные основного входного слоя 5 констант и входные данные тензорного входного слоя 10 тензоров:

This is my neural network, the main input layer inputs 5 constants and tensor input layer inputs 10 tensors.

import numpy as np
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input,Lambda
from keras.layers import Dense
from keras import backend as K


def function(x): #This function is used for the last layer to compute Anisotropic R-S
    tensor = x[0]
    constants = x[1]
    a = K.zeros(shape=(3,3,))

    for i in range(10):
        a = a + constants[:,i]*tensor[:,:,:,i]
    return a


main_input = Input(shape = (5,),name = 'main_input') #The invariant inputs
hidden1 = Dense(10,activation = 'relu')(main_input)
hidden2 = Dense(10,activation= 'relu')(hidden1) #10 constants

tensor_input = Input(shape= (3,3,10,),name = 'tensor_input')

output_layer = Lambda(function)([tensor_input,hidden2])

model = Model(inputs = [main_input,tensor_input], outputs = output_layer)
print(model.summary())

model.compile(optimizer = 'adam', loss='mean_squared_error', metrics=['accuracy'])
plot_model(model, to_file='multilayer_perceptron_graph.png')

#Just test inputs and outputs to correct shape
I_test = np.ones((120,5))
T_test = np.ones((120,3,3,10))
a_test = np.ones((120,3,3))

model.fit({'main_input': I_test, 'tensor_input': T_test},a_test,epochs=50,batch_size=2)

Я использовал лямбда-слой в качестве выходного слоя.Он вычисляет тензор a как: a = g1 * T1 + g2 * T2 + .... g10 * T10, где g - это константы, которые вычисляются из слоя Dense 24.Здесь нет активации, ее простая линейная комбинация.Выход a (3,3) матрица.Входными тензорами являются 10 3 * 3 тензоров, поэтому форма I равна (129,3,3,10).

Я получаю следующую ошибку, когда размер пакета> 1:

InvalidArgumentError: Incompatible shapes: [2] vs. [2,3,3]
     [[Node: training_7/Adam/gradients/lambda_12/mul_9_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@training_7/Adam/gradients/lambda_12/mul_9_grad/Reshape"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training_7/Adam/gradients/lambda_12/mul_9_grad/Shape, training_7/Adam/gradients/lambda_12/mul_9_grad/Shape_1)]]

Пожалуйста, помогите мне решить эту проблему.

...