Я реализовал 3D CNN
в Керасе с TensorFlow, до которого работал очень хорошо. Теперь, чтобы ускорить обучение на нескольких GPU, я хотел попробовать MXNet
с Keras
. Я ожидал, что мне не придется менять большую часть кода, кроме проблемы «channels_last
» на «channel_first», но программа вылетает при операции Conv3D.
Файл keras.json есть, поэтому он должен быть настроен для нормальной работы:
{
"backend": "mxnet",
"image_data_format": "channels_first",
"epsilon": 1e-07,
"floatx": "float32"
}
Это небольшая часть, которая показывает ошибку:
from keras.models import *
from keras.layers import *
from keras.optimizers import *
def SimpleInceptionBlock(input, num_kernels, kernel_init='he_normal', padding='same', bn_axis=1):
tower1 = Conv3D(num_kernels, 1, padding=padding, kernel_initializer=kernel_init)(input)
tower1 = BatchNormalization()(tower1)
tower1 = ELU()(tower1)
tower2 = MaxPooling3D(pool_size=(2, 2, 2), strides=(1, 1, 1), padding=padding)(input)
tower2 = Conv3D(num_kernels, 1, padding=padding, kernel_initializer=kernel_init)(tower2)
tower2 = BatchNormalization()(tower2)
tower2 = ELU()(tower2)
output = concatenate([tower1, tower2], axis=bn_axis)
return output
def TestNet(input_size=(1,64,64,64), num_class=7):
bn_axis = 1
img_input = Input(shape=input_size)
filter1 = SimpleInceptionBlock(img_input, 16)
# this runs fine, filter1.shape = (None, 32, 64, 64, 64)
filter2 = SimpleInceptionBlock(filter1, 16)
output = Conv3D(num_class, (1, 1, 1), activation='softmax', kernel_initializer = kernel_init, padding='same', kernel_regularizer=l2(1e-4))(filter2)
model = Model(input=img_input, output=output)
return model
model = TestNet()
Первый вызов "SimpleInceptionBlock
" выполняется нормально, с filter1.shape = (None, 32, 64, 64)
, как и ожидалось, но второй вызов выдает сообщение об ошибке:
Ошибка в операторе concat0: [15:40:58]
C: \ Дженкинс \ рабочее пространство \ mxnet-тег \ mxnet \ SRC \ оператор \ пп \ concat.cc: 66:
Проверка не удалась: shape_assign (& (* in_shape) [i], dshape) Несовместимый ввод
форма: ожидаемая [0,0,64,64,64], полученная [0,16,64,65,65]