Это ошибка в функции save_weights / load_weights с пользовательским слоем в tenorflow 1.13.1? - PullRequest
1 голос
/ 24 марта 2019

Это заняло у меня больше, чем день, так разочарован.Я сомневаюсь, что это ошибка в Tensorflow 1.13.1 (стабильная версия).

Итак, я создал собственную модель в стиле подклассов моделей, которая содержала только 1 пользовательский слой.После инициализации я записал его обучаемые веса в файл и восстановил его, используя функции save_weights и load_weights.Веса до и после сохранения были различны.

Я также провел тот же тест на Tensorflow 2.0.0a0, и оказалось, что эта версия не получила этого явления.

Мой пользовательский слой:

class EncodingLayer(tf.keras.layers.Layer):
    def __init__(self, out_size):
        super().__init__()
        self.rnn_layer = tf.keras.layers.GRU(out_size, return_sequences=True, return_state=True, recurrent_initializer='glorot_uniform')

    def call(self, X, **kwargs):
        output, state = self.rnn_layer(X)
        return output, state

Это основная часть:

class EncodingModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.encoder_layer = EncodingLayer(out_size=1)

    def infer(self, inputs):
        output, state = self.encoder_layer(inputs)
        return output


if __name__ == '__main__':
    # Comment line below for running in TF 2.0
    tf.enable_eager_execution()

    # shape == (2, 3, 2)
    inputs = tf.convert_to_tensor([
        [[1., 2.], [2., 3.], [4., 4.]],
        [[1., 2.], [2., 3.], [4., 4.]],
    ])

    model = EncodingModel()

    # Just for building the graph
    model.infer(inputs)

    print('Before saving model: ', model.trainable_weights[0].numpy().mean())
    model.save_weights('weight')

    new_model = EncodingModel()
    new_model.infer(inputs)
    new_model.load_weights('weight')
    print('Loaded model: ', new_model.trainable_weights[0].numpy().mean())

Результат при работе в TF 1.13.1:

Before saving model:  0.28864467
Loaded model:  0.117300846

Результат при работе в TF 2.0.0a0:

Before saving model:  -0.06922924
Loaded model:  -0.06922924

Хотя результат предполагает, что это может быть ошибкой, я не был в этом уверен.Так как код очень прост, если такая ошибка существует, ее легко обнаружить.Я много искал, но не нашел ни одного упоминания об этом.Таким образом, я предполагаю, что есть кое-что, что я неправильно понял:)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...