TensorFlow computes incorrect loss for `tf.keras` models when using sample weights - PullRequest
November 15, 2018

The loss is computed incorrectly when working with tf.keras. After the model is built, `model.fit_generator` should accept a generator that yields (inputs, targets, sample_weights) tuples. However, if I multiply sample_weights by 10000, the loss does not change.
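For reference, with a constant weight the weighted loss is expected to be exactly that constant times the unweighted loss. The NumPy sketch below assumes the classic Keras weighted-objective reduction (each per-sample loss is multiplied by its weight, divided by the fraction of non-zero weights, then averaged over the batch); `binary_crossentropy` and `weighted_loss` here are hypothetical helpers written for illustration, not the actual tf.keras code.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Per-sample binary cross-entropy; predictions are clipped to avoid log(0).
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))

def weighted_loss(y_true, y_pred, sample_weight):
    # Assumed Keras-style reduction: scale each per-sample loss by its weight,
    # normalize by the fraction of non-zero weights, then take the batch mean.
    per_sample = binary_crossentropy(y_true, y_pred) * sample_weight
    per_sample /= np.mean(sample_weight != 0)
    return np.mean(per_sample)

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=32).astype(float)
y_pred = rng.uniform(0.05, 0.95, size=32)
ones = np.ones(32)

base = weighted_loss(y_true, y_pred, ones)            # weights = 1
scaled = weighted_loss(y_true, y_pred, ones * 10000)  # weights = 10000
print(scaled / base)  # close to 10000
```

Under this assumption a loss that stays at roughly 0.7 after the weights are multiplied by 10000 means the weights never reached the loss.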

The bug appears to be present in TensorFlow 1.10 and later, e.g. 1.11 and 1.12.

Code to reproduce

import numpy as np
import tensorflow as tf

WEIGHT_VARIABLE = 1  # constant per-sample weight; set to 10000 to reproduce the problem

no_of_features = 10
timesteps = 3
batch_size = 32

def data_gen():

    while True:
        numerical = np.random.randint(5, size=(batch_size, timesteps, no_of_features))
        y = np.random.randint(2, size=batch_size)
        w = np.ones(batch_size) * WEIGHT_VARIABLE

        # Each batch is an (inputs, targets, sample_weights) tuple
        yield {'numeric_input': numerical}, y, w


def build_model():
    numerical_input = tf.keras.layers.Input(shape=(timesteps, no_of_features), name='numeric_input')
    rnn_out = tf.keras.layers.GRU(32, return_sequences=False)(numerical_input)
    dense = tf.keras.layers.Dense(1, activation='sigmoid', name='main_output')(rnn_out)

    model = tf.keras.models.Model(numerical_input, dense)

    params = {
        'loss': 'binary_crossentropy',
        'optimizer': tf.keras.optimizers.Adam(),
        'metrics': [tf.keras.metrics.binary_crossentropy, tf.keras.metrics.binary_accuracy]
    }
    model.compile(**params)

    return model


def train_model():
    gen1 = data_gen()
    model = build_model()

    model.fit_generator(gen1, epochs=5, steps_per_epoch=10)


if __name__ == '__main__':
    train_model()

In the code above, just change WEIGHT_VARIABLE from 1 to 10000 and rerun the file.
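Because `data_gen()` draws fresh random batches, the two runs are not trained on the same data. A minimal sketch of a seeded variant, assuming a hypothetical `seed` parameter and a private `np.random.RandomState`, so that runs differing only in the weight see identical inputs and labels (model weight initialization is also random; in TF 1.x one would typically also call `tf.set_random_seed`):

```python
import numpy as np

no_of_features = 10
timesteps = 3
batch_size = 32

def data_gen(weight, seed=0):
    # Hypothetical seeded variant of data_gen(): a private RandomState makes
    # every run draw the same inputs and labels, so two runs that differ only
    # in `weight` see identical batches.
    rng = np.random.RandomState(seed)
    while True:
        numerical = rng.randint(5, size=(batch_size, timesteps, no_of_features))
        y = rng.randint(2, size=batch_size)
        w = np.ones(batch_size) * weight
        yield {'numeric_input': numerical}, y, w

# Compare the first batch of a weight=1 run with a weight=10000 run.
x1, y1, w1 = next(data_gen(1))
x2, y2, w2 = next(data_gen(10000))
print(np.array_equal(x1['numeric_input'], x2['numeric_input']), np.array_equal(y1, y2))
```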


Logs

v1.10

WEIGHT_VARIABLE = 1 

Epoch 1/5
10/10 [==============================] - 1s 128ms/step - loss: 0.7407 - binary_crossentropy: 0.7407 - binary_accuracy: 0.5031
Epoch 2/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7043 - binary_crossentropy: 0.7043 - binary_accuracy: 0.5125
Epoch 3/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7055 - binary_crossentropy: 0.7055 - binary_accuracy: 0.5219
Epoch 4/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7002 - binary_crossentropy: 0.7002 - binary_accuracy: 0.5250
Epoch 5/5
10/10 [==============================] - 0s 4ms/step - loss: 0.6944 - binary_crossentropy: 0.6944 - binary_accuracy: 0.5375

WEIGHT_VARIABLE = 10000

Epoch 1/5
10/10 [==============================] - 1s 131ms/step - loss: 7235.5976 - binary_crossentropy: 0.7236 - binary_accuracy: 0.4562
Epoch 2/5
10/10 [==============================] - 0s 4ms/step - loss: 7271.9184 - binary_crossentropy: 0.7272 - binary_accuracy: 0.4844
Epoch 3/5
10/10 [==============================] - 0s 4ms/step - loss: 7276.9147 - binary_crossentropy: 0.7277 - binary_accuracy: 0.4500
Epoch 4/5
10/10 [==============================] - 0s 4ms/step - loss: 7052.0121 - binary_crossentropy: 0.7052 - binary_accuracy: 0.4625
Epoch 5/5
10/10 [==============================] - 0s 4ms/step - loss: 7187.0285 - binary_crossentropy: 0.7187 - binary_accuracy: 0.4969
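If a constant weight w simply scales every per-sample loss, the reported loss should equal w times the unweighted metric. The v1.10 numbers fit that exactly: dividing each loss by 10^4 reproduces the binary_crossentropy column to the printed precision (that column is presumably unweighted because it is passed through `metrics` rather than `weighted_metrics`).

```latex
\begin{aligned}
\text{loss} &= w \cdot \overline{\mathrm{BCE}}, \qquad w = 10^4 \\
\text{Epoch 1:}\quad 7235.5976 / 10^4 &= 0.72355976 \approx 0.7236 \\
\text{Epoch 2:}\quad 7271.9184 / 10^4 &= 0.72719184 \approx 0.7272 \\
\text{Epoch 5:}\quad 7187.0285 / 10^4 &= 0.71870285 \approx 0.7187
\end{aligned}
```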

v1.12

WEIGHT_VARIABLE = 1 

Epoch 1/5
10/10 [==============================] - 1s 68ms/step - loss: 0.7188 - binary_crossentropy: 0.7188 - binary_accuracy: 0.5312
Epoch 2/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7044 - binary_crossentropy: 0.7044 - binary_accuracy: 0.4969
Epoch 3/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7086 - binary_crossentropy: 0.7086 - binary_accuracy: 0.4844
Epoch 4/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7075 - binary_crossentropy: 0.7075 - binary_accuracy: 0.4500
Epoch 5/5
10/10 [==============================] - 0s 4ms/step - loss: 0.6950 - binary_crossentropy: 0.6950 - binary_accuracy: 0.5187

WEIGHT_VARIABLE = 10000

Epoch 1/5
10/10 [==============================] - 1s 74ms/step - loss: 0.9084 - binary_crossentropy: 0.9084 - binary_accuracy: 0.4719
Epoch 2/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7120 - binary_crossentropy: 0.7120 - binary_accuracy: 0.5062
Epoch 3/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7024 - binary_crossentropy: 0.7024 - binary_accuracy: 0.5344
Epoch 4/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7257 - binary_crossentropy: 0.7257 - binary_accuracy: 0.4500
Epoch 5/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7013 - binary_crossentropy: 0.7013 - binary_accuracy: 0.4844
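The two versions can be told apart mechanically by comparing the reported loss with the binary_crossentropy metric on the same line. The helpers below (`weight_ratio`, `weights_applied`) are hypothetical and assume the log format shown above; they check whether loss is approximately weight × metric.

```python
import re

# Matches the "loss" and "binary_crossentropy" fields of a Keras progress line.
LOG_RE = re.compile(r"- loss: ([\d.]+) - binary_crossentropy: ([\d.]+)")

def weight_ratio(line):
    # Ratio of the reported (possibly weighted) loss to the unweighted metric.
    m = LOG_RE.search(line)
    if m is None:
        raise ValueError("no loss/binary_crossentropy fields in line")
    loss, bce = float(m.group(1)), float(m.group(2))
    return loss / bce

def weights_applied(line, weight, rel_tol=1e-3):
    # True if the ratio is within rel_tol (relative) of the sample weight,
    # i.e. the constant weight appears to have reached the loss.
    return abs(weight_ratio(line) - weight) <= rel_tol * weight

v110 = "1s 131ms/step - loss: 7235.5976 - binary_crossentropy: 0.7236 - binary_accuracy: 0.4562"
v112 = "1s 74ms/step - loss: 0.9084 - binary_crossentropy: 0.9084 - binary_accuracy: 0.4719"

print(weights_applied(v110, 10000))  # True
print(weights_applied(v112, 10000))  # False
```

For the v1.12 run the ratio is exactly 1, so with WEIGHT_VARIABLE = 10000 the weights do not appear in the loss at all.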

Link to the GitHub issue
