Я пытаюсь использовать предварительно обученную модель TF Slim ResNet-v1-101 для инициализации двух входных ветвей модели с разными входами (RGB features[0]
и карта глубины features[1]
соответственно) для дальнейшего их объединения и подачи их совместный результат для некоторых пространственных слоев свертки. Чтобы добиться этого, я создаю отдельную область видимости для каждой ветви и инициализирую модели с предварительно обученной контрольной точкой, исключая последний полностью подключенный слой (здесь он называется logits
):
with tf.variable_scope("rgb_branch"):
with tf.contrib.slim.arg_scope(resnet_v1.resnet_arg_scope()):
rgb_logits, end_points = resnet_v1.resnet_v1_101(features[0], self.num_classes, is_training=is_training)
rgb_variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=['rgb_branch/resnet_v1_101/logits'])
# strip scope name
rgb_assignment_map = { rgb_variables_to_restore[0].name.split(':')[0] : rgb_variables_to_restore[0]}
rgb_assignment_map.update({ v.name.split(':')[0].split('/', 1)[1] : v for v in rgb_variables_to_restore[1:] })
tf.train.init_from_checkpoint(self.pre_trained_model_path, assignment_map)
with tf.variable_scope("depth_branch"):
with tf.contrib.slim.arg_scope(resnet_v1.resnet_arg_scope()):
depth_logits, end_points = resnet_v1.resnet_v1_101(features[1], self.num_classes, is_training=is_training)
depth_variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=['depth_branch/resnet_v1_101/logits'])
depth_assignment_map = { depth_variables_to_restore[0].name.split(':')[0] : depth_variables_to_restore[0]}
depth_assignment_map.update({ v.name.split(':')[0].split('/', 1)[1] : v for v in depth_variables_to_restore[1:] })
tf.train.init_from_checkpoint(self.pre_trained_model_path, assignment_map)
Проблема в том, что во время второй инициализации (ветвь depth
) TF жалуется на форму слоя logits
, как будто он никогда не удалялся:
ValueError: Shape of variable rgb_branch/resnet_v1_101/logits/biases:0 ((3,)) doesn't match with shape of tensor resnet_v1_101/logits/biases ([1000]) from checkpoint reader.
Эта проблема не возникает, когда я инициализирую только одну ветвь, а связана только с первой из определенных ветвей - если я изменю порядок ветвей, вышеприведенная ошибка изменится на depth_branch/resnet_v1_101/logits/biases:0
Я что-то не так делаю? Карты назначений, кажется, отображают одинаковые имена из контрольной точки в переменные из отдельных графиков.
Есть ли другой способ добиться этого с помощью TF Slim?
Спасибо!