Как вы собираете элементы y_pred, которые не соответствуют истинной метке в пользовательской функции потерь Keras / tf2.0? - PullRequest
0 голосов
/ 22 декабря 2019

Ниже приведен простой пример того, что я хотел бы сделать:

import numpy as np

y_true = np.array([0,0,1])
y_pred = np.array([0.1,0.2,0.7])

yc = (1-y_true).astype('bool')

desired = y_pred[yc]

>>> desired
>>> array([0.1, 0.2])

Итак, предсказание, соответствующее основной истине, равно 0,7, я хочу оперировать массивом, содержащим все элементыy_pred, за исключением основного элемента истины.

Я не уверен, как заставить это работать в Керасе. Вот рабочий пример проблемы в функции потерь. Прямо сейчас «желаемый» ничего не делает, но вот с чем мне нужно поработать:

# using tensorflow 2.0.0 and keras 2.3.1

import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras.layers import Input,Dense,Flatten
from tensorflow.keras.models import Model
from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize data.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

input_shape = x_train.shape[1:]


x_in = Input((input_shape))

x = Flatten()(x_in)
x = Dense(256,'relu')(x)
x = Dense(256,'relu')(x)
x = Dense(256,'relu')(x)

out = Dense(10,'softmax')(x)




def loss(y_true,y_pred):


    yc = tf.math.logical_not(kb.cast(y_true, 'bool'))
    desired = tf.boolean_mask(y_pred,yc,axis = 1)    #Remove and it runs


    CE = tf.keras.losses.categorical_crossentropy(
        y_true,
        y_pred)

    L = CE

    return L



model = Model(x_in,out)

model.compile('adam',loss = loss,metrics = ['accuracy'])


model.fit(x_train,y_train)

Я получаю сообщение об ошибке

ValueError: Shapes (10,) and (None, None) are incompatible

Где 10 - это числокатегорий. Конечная цель состоит в том, чтобы реализовать это: ComplementEntropy в Керасе, где моя проблема, кажется, строки 26-28.

...