Как изменить элементы объекта Keras Tensor - PullRequest
1 голос
/ 19 апреля 2020

Я строю нейронную сеть Convolution в Керасе, которая получает пакет изображений с размерами (Нет, 256, 256, 1), и выходные данные будут пакетами с размером (Нет, 256, 256, 3). Теперь после окончательного вывода слоя я хочу добавить слой, который присваивает значения некоторым пикселям в выходном слое на основе условия значения на входах. Вот что я попробовал:

Функция

def SetBoundaries(ins):
    xi = ins[0]
    xo = ins[1]

    bnds = np.where(xi[:, :, :, 0] == 0)
    bnds_s, bnds_i, bnds_j = bnds[0], bnds[1], bnds[2]
    xo[bnds_s, bnds_i, bnds_j, 0] = 0
    xo[bnds_s, bnds_i, bnds_j, 1] = 0
    xo[bnds_s, bnds_i, bnds_j, 2] = 0

    return xo

Модель Keras

def conv_res(inputs):
    x0 = inputs

    ...

    xc = conv_layer(xc, kernel_size=3, stride=1,
                    num_filters=3, name="Final_Conv")

    # apply assignment function
    xc = Lambda(SetBoundaries, name="assign_boundaries")([x0, xc])
    return xc

Наконец, модель построен с использованием

def build_model(inputs):
    xres = int(inputs.shape[1])
    yres = int(inputs.shape[2])
    cres = int(inputs.shape[3])

    inputs = Input((xres, yres, cres))
    outputs = UNet.conv_res(inputs)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

Однако при запуске я получаю сообщение об ошибке:

NotImplementedError: Cannot convert a symbolic Tensor (assign_boundaries/Equal:0) to a numpy array.

Все прекрасно работает без функции Lambda . Я понимаю, что проблема заключается в присвоении значения объекту Tensor, но как мне добиться того, что я хочу?

Спасибо

Ответы [ 2 ]

0 голосов
/ 20 апреля 2020

Мне удалось заставить его работать, изменив функцию на:

def SetBoundaries(ins):
    xi = ins[0]
    xo = ins[1]

    xin = tf.broadcast_to(xi, tf.shape(xo))
    mask = K.cast(tf.not_equal(xin, 0), dtype="float32")
    xf = layers.Multiply()([mask, xo])

    return xf
0 голосов
/ 20 апреля 2020

np.where работает с NumPy массивами, но вывод вашей модели - тензор Tensorflow. Попробуйте использовать tf.where , что то же самое, но для tf.Tensor s.

...