Загрузка пользовательского слоя CT C из файла h5 в Keras - PullRequest
0 голосов
/ 10 июля 2020

У меня есть такой класс CTCLayer:

class CTCLayer(layers.Layer):
def __init__(self, name=None):
    super().__init__(name=name)
    self.loss_fn = keras.backend.ctc_batch_cost


def call(self, y_true, y_pred):
    # Compute the training-time loss value and add it
    # to the layer using `self.add_loss()`.
    batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
    input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

    input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
    label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

    loss = self.loss_fn(y_true, y_pred, input_length, label_length)
    self.add_loss(loss)

    # At test time, just return the computed predictions
    return y_pred

Я обучил свою модель, сохранил ее в файле model.h5 и загрузил через:

model_load = tf.keras.models.load_model('model.h5', custom_objects={'CTCLayer': CTCLayer})

Это бросает init () получил неожиданный аргумент ключевого слова 'trainable' error.

Поскольку я не хочу снова обучать свою модель (временные ограничения), есть ли обходной путь, который я могу сделать, чтобы загрузить модель, не добавляя get_config () в класс CTCLayer?

А если нет, как мне изменить get_config () в классе?

...