Вы можете изменить функцию потерь на то, что умножает значения потерь на соответствующие веса в вашей матрице.
Итак, в качестве примера рассмотрим пример тензорного потока mnist :
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
, если мы хотим изменить это, чтобы взвесить потери на основе следующей матрицы:
weights = tf.constant([
[1., 1.2, 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1.2, 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 10.9, 1.2, 1., 1., 1., 1., 1., 1.],
[1., 0.9, 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
тогда мы можем обернуть существующий sparse_categorical_crossentropy
в новую пользовательскую функцию потерь, которая умножает потери на соответствующий вес. Примерно так:
def custom_loss(y_true, y_pred):
# get the prediction from the final softmax layer:
pred_idx = tf.argmax(y_pred, axis=1, output_type=tf.int32)
# stack these so we have a tensor of [[predicted_i, actual_i], ...,] for each i in batch
indices = tf.stack([tf.reshape(pred_idx, (-1,)),
tf.reshape(tf.cast( y_true, tf.int32), (-1,))
], axis=1)
# use tf.gather_nd() to convert indices to the appropriate weight from our matrix [w_i, ...] for each i in batch
batch_weights = tf.gather_nd(weights, indices)
return batch_weights * tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
Затем мы можем использовать эту новую пользовательскую функцию потерь в модели:
model.compile(optimizer='adam',
loss=custom_loss,
metrics=['accuracy'])