Это заняло у меня больше, чем день, так разочарован.Я сомневаюсь, что это ошибка в Tensorflow 1.13.1
(стабильная версия).
Итак, я создал собственную модель в стиле подклассов моделей, которая содержала только 1 пользовательский слой.После инициализации я записал его обучаемые веса в файл и восстановил его, используя функции save_weights и load_weights.Веса до и после сохранения были различны.
Я также провел тот же тест на Tensorflow 2.0.0a0
, и оказалось, что эта версия не получила этого явления.
Мой пользовательский слой:
class EncodingLayer(tf.keras.layers.Layer):
def __init__(self, out_size):
super().__init__()
self.rnn_layer = tf.keras.layers.GRU(out_size, return_sequences=True, return_state=True, recurrent_initializer='glorot_uniform')
def call(self, X, **kwargs):
output, state = self.rnn_layer(X)
return output, state
Это основная часть:
class EncodingModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.encoder_layer = EncodingLayer(out_size=1)
def infer(self, inputs):
output, state = self.encoder_layer(inputs)
return output
if __name__ == '__main__':
# Comment line below for running in TF 2.0
tf.enable_eager_execution()
# shape == (2, 3, 2)
inputs = tf.convert_to_tensor([
[[1., 2.], [2., 3.], [4., 4.]],
[[1., 2.], [2., 3.], [4., 4.]],
])
model = EncodingModel()
# Just for building the graph
model.infer(inputs)
print('Before saving model: ', model.trainable_weights[0].numpy().mean())
model.save_weights('weight')
new_model = EncodingModel()
new_model.infer(inputs)
new_model.load_weights('weight')
print('Loaded model: ', new_model.trainable_weights[0].numpy().mean())
Результат при работе в TF 1.13.1:
Before saving model: 0.28864467
Loaded model: 0.117300846
Результат при работе в TF 2.0.0a0:
Before saving model: -0.06922924
Loaded model: -0.06922924
Хотя результат предполагает, что это может быть ошибкой, я не был в этом уверен.Так как код очень прост, если такая ошибка существует, ее легко обнаружить.Я много искал, но не нашел ни одного упоминания об этом.Таким образом, я предполагаю, что есть кое-что, что я неправильно понял:)