Пользовательская функция потерь Keras: переменная с формой batch_size (y_true) - PullRequest
0 голосов
/ 27 января 2019

При реализации пользовательской функции потерь в Keras мне требуется tf.Variable с формой размера пакета моих входных данных (y_true, y_pred).

def custom_loss(y_true, y_pred):

    counter = tf.Variable(tf.zeros(K.shape(y_true)[0], dtype=tf.float32))
    ...

Однако это приводит к ошибке:

You must feed a value for placeholder tensor 'dense_17_target' with dtype float and shape [?,?]

Если я фиксирую batch_size в значении:

def custom_loss(y_true, y_pred):

    counter = tf.Variable(tf.zeros(batch_size, dtype=tf.float32))
    ...

, так что |training_set| % batch_size и |val_set| % batch_size равны нулю, все работает нормально.

Есть ли предложения, почему не работает присвоение переменной с размером пакета на основе формы ввода (y_true и y_pred)?

РЕШЕНИЕ

Я нашел удовлетворительное решение, которое работает.Я инициализировал переменную с максимально возможным размером batch_size (указанным во время сборки модели) и использовал K.shape(y_true)[0] только для нарезки переменной.Таким образом, это работает отлично.Вот код:

def custom_loss(y_true, y_pred):
    counter = tf.Variable(tf.zeros(batch_size, dtype=tf.float32))
    ...
    true_counter = counter[:K.shape(y_true)[0]]
    ...

Ответы [ 2 ]

0 голосов
/ 20 марта 2019

Альтернативное решение - создать переменную и динамически изменить ее форму, используя tf.assign с validate_shape=False:

counter = tf.Variable(0.0)
...
val = tf.zeros(tf.shape(y_true)[:1], 0.0)
counter = tf.assign(counter, val, validate_shape=False)
0 голосов
/ 27 января 2019

Это не работает, потому что K.shape возвращает вам символическую форму, которая сама является тензором, а не кортежем значений int. Чтобы получить значение от тензора, вы должны оценить его в рамках сеанса. Смотрите документацию для этого. Чтобы получить реальное значение до времени оценки, используйте K.int_shape: https://keras.io/backend/#int_shape

Однако K.int_shape здесь также не сработает, поскольку это просто статические метаданные, которые обычно не отражают текущий размер пакета, но имеют значение-заполнитель None.

Решение, которое вы нашли (контролируйте размер партии и используйте его в потере), действительно хорошее.

Я полагаю, что проблема в том, что вам нужно знать размер пакета во время определения, чтобы построить переменную, но он будет известен только во время сеанса.

Если вы работали с ним как с тензором, все должно быть в порядке, см. пример .

...