Входы теряются, когда я загружаю сохраненную модель в Keras с пользовательской потерей, которая принимает несколько входов - PullRequest
1 голос
/ 08 октября 2019

tf. версия '1.12.0'

У меня есть функция пользовательских потерь, которая принимает несколько входов. Работает нормально, за исключением случаев, когда я пытаюсь сохранить и загрузить модель. Вот простой пример, который показывает, как теряются входные данные. Пожалуйста, посмотрите.

x = tf.keras.Input(shape=(5,), name='input')
y_true = tf.keras.Input(shape=(5,), name='y_true' )
y_pred = tf.keras.layers.Dense(5)(x)
other_data = tf.keras.Input(shape=(5,), name='other_data' )
model = tf.keras.Model(inputs=[x, y_true, other_data],  outputs=y_pred)

def custom_loss(y_true, y_pred):
    return tf.reduce_sum(tf.pow(y_true -y_pred,2)) + tf.reduce_sum(tf.multiply(y_pred,other_data))

model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.01, beta_2=0.99), loss=custom_loss)

data = np.random.rand(5,5)
model.fit([data, data, data], data)

model.save('tmp.h5')
print(model.input_names)
model1 = tf.keras.models.load_model('tmp.h5', custom_objects={'custom_loss':custom_loss})
print(model1.input_names)

model1.fit([data, data, data], data)

Эпоха 1/1 5/5 [=============================] - 0 с 42 мс / шаг - потеря: 9,4392

['input', 'y_true', 'other_data'] <------------- Это хорошо </strong>

['input'] <----------- Что здесь произошло? </strong>

Traceback (большинствопоследний вызов последний):

Файл "", строка 21, в model1.fit ([данные, данные, данные], данные)

Файл "C: \ src \ Anaconda3 \ envs \"deepema \ lib \ site-packages \ tenorflow \ python \ keras \ engine \ training.py ", строка 1536, в форме validation_split = validation_split)

Файл" C: \ src \ Anaconda3 \ envs \ deepema \ lib\ site-packages \ tenorflow \ python \ keras \ engine \ training.py ", строка 992, в _standardize_user_data class_weight, batch_size)

Файл" C: \ src \ Anaconda3 \ envs \ deepema \ lib \ site- "packages \ tenorflow \ python \ keras \ engine \ training.py ", строка 1117, в _standardize_weights exception_prefix = 'input')

Файл" C: \ src \ Anaconda3 \ envs \ deepema \ lib \ site-packages "\ tensorflow \ питон \ keras \ двигатель \ трaining_utils.py ", строка 293, в standardize_input_data str (len (data)) + 'arrays:' + str (data) [: 200] + '...')

ValueError: Ошибка при проверке моделивход: список массивов Numpy, которые вы передаете своей модели, не соответствует размеру, который ожидала модель. Ожидается увидеть 1 массив (ов), но вместо этого получен следующий список из 3 массивов: [array ([[0.12768201, 0.06106967, 0.99779087, 0.50767692, 0.21839594]], [0.82444334, 0.1367274, 0.14495117, 0.32396153, 0.24457874], 0.298700,40644681, 0,69308081, 0,30091417, 0,776 ...

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...