Пользовательская функция потерь Keras для категориальных данных с двоичным кодированием (не с горячим кодированием) - PullRequest
3 голосов
/ 18 апреля 2019

Мне нужна помощь в написании пользовательской функции потерь / метрики для Keras. Мои категории в двоичном коде (не один горячий). Я хотел бы сделать побитовое сравнение между реальными классами и предсказанными классами.

Например, Реальная метка: 0x1111111111 Предсказанная метка: 0x1011101111

Предсказанная метка имеет 8 из 10 правильных битов, поэтому точность этого соответствия должна быть 0,8, а не 0,0. Я понятия не имею, как я могу это сделать с помощью команд Keras.

РЕДАКТИРОВАТЬ 1: В настоящее время я использую что-то вроде этого, но это еще не работает:

def custom_binary_error(y_true, y_pred, n=11):
    diff_dec = K.tf.bitwise.bitwise_xor(K.tf.cast(y_true, K.tf.int32), K.tf.cast(y_pred, K.tf.int32))
    diff_bin = K.tf.mod(K.tf.bitwise.right_shift(K.tf.expand_dims(diff_dec,1), K.tf.range(n)), 2)
    diff_sum = K.tf.math.reduce_sum(diff_bin, 1)
    diff_percent = K.tf.math.divide(diff_sum, 11)
    return K.tf.math.reduce_mean(diff_percent, 0)

Я получаю эту ошибку:

ValueError: Dimensions must be equal, but are 2048 and 11 for 'loss/activation_1_loss/RightShift' (op: 'RightShift') with input shapes: [?,1,2048], [11].

Ответы [ 2 ]

0 голосов
/ 29 апреля 2019

Вот как вы можете определить свою ошибку:

import tensorflow as tf

def custom_binary_error(y_true, y_pred):
    y_true = tf.cast(y_true, tf.bool)
    y_pred = tf.cast(y_pred, tf.bool)
    xored = tf.logical_xor(y_true, y_pred)
    notxored = tf.logical_not(xored)
    sum_xored = tf.reduce_sum(tf.cast(xored, tf.float32))
    sum_notxored = tf.reduce_sum(tf.cast(notxored, tf.float32))
    return sum_xored / (sum_xored + sum_notxored)

Тестирование с двумя метками размером 6:

import tensorflow as tf

y_train_size = 6

y_train = [[1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0]]
y_pred = tf.convert_to_tensor([[1, 1, 1, 1, 0, 0], [0, 0, 0, 0, 1, 0]])
y = tf.placeholder(tf.int32, shape=(None, y_train_size))
error = custom_binary_error(y, y_pred)
with tf.Session() as sess:
    res = sess.run(error, feed_dict={y:y_train})
    print(res) # 0.25

Использование в Keras:

import tensorflow as tf
import numpy as np

y_train_size = 6

def custom_binary_error(y_true, y_pred):
    y_true = tf.cast(y_true, tf.bool)
    y_pred = tf.cast(y_pred, tf.bool)
    xored = tf.logical_xor(y_true, y_pred)
    notxored = tf.logical_not(xored)
    sum_xored = tf.reduce_sum(tf.cast(xored, tf.float32))
    sum_notxored = tf.reduce_sum(tf.cast(notxored, tf.float32))
    return sum_xored / (sum_xored + sum_notxored)

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(y_train_size))

model.compile(optimizer=tf.keras.optimizers.SGD(0.01),
              loss=[tf.keras.losses.MeanAbsoluteError()],
              metrics=[custom_binary_error])

y_train = np.array([[1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0]])
x_train = np.random.normal(size=(2, 2))

model.fit(x_train, y_train, epochs=2)

приведет к:

Epoch 1/2
2/2 [==============================] - 0s 23ms/sample - loss: 1.4097 - custom_binary_error: 0.5000
Epoch 2/2
2/2 [==============================] - 0s 328us/sample - loss: 1.4017 - custom_binary_error: 0.5000

Примечание

Если вы хотите точность вместо ошибка функция custom_binary_error() должна вернуть

sum_notxored / (sum_xored + sum_notxored)
0 голосов
/ 19 апреля 2019

Я пытаюсь что-то с предположением, что y_true, y_pred являются положительными целыми числами.

def custom_binary_error(y_true, y_pred):
    width = y_true.bit_length() if y_true.bit_length() > y_pred.bit_length() else y_pred.bit_length()       # finds the greater width of bit sequence, not sure if needed
    diff = np.bitwise_xor(y_true, y_pred)       # 1 when different, 0 when same
    error = np.binary_repr(diff, width=width).count('1')/width       # calculate % of '1's
    return K.variable(error)

Используйте 1-error для точности.Я не проверял это;это просто для того, чтобы дать представление.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...