Я использую tf.image.resize_bilinear в сети сегментации. Кажется, эта функция не поддерживается моделью с несколькими GPU. В следующем коде показана упрощенная ситуация: (которая может быть запущена напрямую)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1'
from keras.backend.tensorflow_backend import set_session
from keras import backend as K
from keras.utils import multi_gpu_model
from keras.applications.mobilenet_v2 import preprocess_input
import tensorflow as tf
import numpy as np
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
sess = tf.Session(config=config)
set_session(sess)
batch = 4
num_classes = 2
size = 128
K.clear_session()
def _GetRandomImg():
shape = (batch, size, size, 3)
img = np.random.randint(low=0, high=256, size=shape)
return preprocess_input(img)
def _GetRandomLabel():
shape = (batch, size, size, num_classes)
label = np.random.randint(low=0, high=num_classes, size=shape)
label = np.exp(label)
label = label/ np.sum(label, axis=-1, keepdims=True)
return label
def DataGen():
while True:
x = _GetRandomImg()
y = _GetRandomLabel()
yield x, y
from keras.layers import Input, Conv2D, Lambda
from keras import Model
def GetModel():
inputs = Input(shape=(size, size, 3))
f = lambda x: tf.image.resize_bilinear(inputs, (size, size), align_corners=True)
x = Lambda(f, output_shape=(size, size, 3))(inputs)
outputs = Conv2D(num_classes, kernel_size=3, padding='same')(x)
model = Model(inputs=[inputs], outputs=[outputs])
return model
gen = DataGen()
with tf.device('/cpu:0'):
model = GetModel()
model = multi_gpu_model(model, gpus=2)
model.compile(loss='categorical_crossentropy', optimizer='sgd')
result = model.fit_generator(gen, epochs=2, verbose = 1, steps_per_epoch = 100)
она отлично работает в среде с одним графическим процессором, но в среде с несколькими графическими процессорами возникла следующая ошибка:
InvalidArgumentError: Incompatible shapes: [3,128,128,2] vs. [6,128,128,2]
[[{{node loss/conv2d_1_loss/categorical_crossentropy/mul}}]]
[[{{node training/SGD/gradients/conv2d_1_1/concat_grad/Slice_1}}]]