Керас: Как сравнить тензор с целым числом внутри лямбда-слоя? - PullRequest
0 голосов
/ 12 февраля 2020

Я кормлю свой лямбда-слой двумя тензорами. Один из них - тензор (batch_size, 14, 14, 10) из предыдущего сверточного слоя. Другой - тензор (batch_size, 2), обозначающий метки классов для соответствующих выборок в пакете (например, [1. 0.] обозначает класс 0, а [0. 1.] обозначает класс 1).

Для каждого образца в пакете мне нужно, чтобы мой лямбда-слой заполнил первые 5 карты объектов для образца в тензоре (batch_size, 14, 14, 10 ) нулями если образец относится к классу 1. Аналогично, последние 5 карты объектов должны быть заполнены нулями , если образец относится к классу 0.

def my_lambda_layer(inputs):
input_tensor = inputs[0] #shape: (batch_size, 14, 14, 10)
class_targets = inputs[1] #shape: (batch_size, 2)
input_tensor_shape = input_tensor.shape.as_list()

#shape: (batch_size, 14, 14, 10)
filter_ = numpy.ones(shape=(batch_size, input_tensor_shape[1],
                         input_tensor_shape[2], input_tensor_shape[3]))

for i in list(range(0, batch_size)):
    label = K.switch(K.equal(class_targets[i, 0], 1), 0, 1)
    if label == 1:
        filter_[i, :, :, 0:5] *= 0 #Fill first 5 feature maps with zeros.
    elif label == 0:
        filter_[i, :, :, 5:10] *= 0 #Fill last 5 feature maps with zeros.

input_tensor *= filter_
return input_tensor

Полагаю, проблема в том, что операторы if никогда не будут выполнены, поскольку я сравниваю тензор с целым числом:

    if label == 1:
        filter_[i, :, :, 0:5] *= 0
    elif label == 0:
        filter_[i, :, :, 5:10] *= 0

keras.backend.eval (label), очевидно, победил ' Это не работает, так как он отображает ошибки при использовании внутри лямбда-слоя. Как я могу обновить элементы карт объектов на основе метки класса для каждого образца?

...