Как использовать feed_dict для модели глубокого перезапуска? - PullRequest
0 голосов
/ 24 декабря 2018

Я пытаюсь использовать предоставленную здесь модель глубокого отсчета https://github.com/DrSleep/tensorflow-deeplab-resnet

Я могу успешно обучить модель с моим набором данных.

Также результаты были хорошими.Но в этом репо нет информации для использования модели для динамического прогнозирования

. Поэтому я попытался загрузить график модели и сделать прогноз

Ниже кода, который я использовал для загрузки

img = <img path>
sess = tf.Session()
mod_dir = actual_path(model_dir)
model_file = os.path.join(mod_dir, "snapshots_fine_tune_new_2")
model_file = os.path.join(model_file, "model.ckpt-1400")
loader = tf.train.import_meta_graph(model_file+".meta")
loader.restore(sess, model_file)
placeholder = tf.placeholder(name='data', dtype=tf.float32,
                             shape=[None, None,
                                    3])

# Predictions.
raw_output = sess.graph.get_tensor_by_name('fc1_voc12:0')
raw_output = tf.image.resize_bilinear(raw_output, tf.shape(placeholder)[1:3, 
])

# CRF.
inv_image = tf.py_func(inv_preprocess, [placeholder], tf.uint8)
raw_output = tf.py_func(dense_crf, [tf.nn.softmax(raw_output), inv_image], 
tf.float32)

raw_output = tf.argmax(raw_output, dimension=3)
pred = tf.expand_dims(raw_output, dim=3)  # Create a 4-d tensor.

img = tf.expand_dims(read_image(img), dim=0)

preds = sess.run(pred,feed_dict= 
                         {sess.graph.get_tensor_by_name("data:0"):img})

Но при выполнении я получаю эту ошибку

KeyError: «Имя« data: 0 »относится к Tensor, который не существует. Операция« data »не существуетсуществует на графике. "

Ниже кода модели

class DeepLabResNetModel(Network):
def setup(self, is_training):
    '''Network definition.

    Args:
      is_training: whether to update the running mean and variance of the batch normalisation layer.
                   If the batch size is small, it is better to keep the running mean and variance of 
                   the-pretrained model frozen.
    '''
    (self.feed('data')
         .conv(7, 7, 64, 2, 2, biased=False, relu=False, name='conv1')
         .batch_normalization(is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv1')
         .max_pool(3, 3, 2, 2, name='pool1')
         .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='res2a_branch1')
         .batch_normalization(is_training=is_training, activation_fn=None, name='bn2a_branch1'))

    (self.feed('pool1')
         .conv(1, 1, 64, 1, 1, biased=False, relu=False, name='res2a_branch2a')
         .batch_normalization(is_training=is_training, activation_fn=tf.nn.relu, name='bn2a_branch2a')
         .conv(3, 3, 64, 1, 1, biased=False, relu=False, name='res2a_branch2b')
         .batch_normalization(is_training=is_training, activation_fn=tf.nn.relu, name='bn2a_branch2b')
         .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='res2a_branch2c')
         .batch_normalization(is_training=is_training, activation_fn=None, name='bn2a_branch2c'))

    (self.feed('bn2a_branch1', 
               'bn2a_branch2c')
         .add(name='res2a')
         .relu(name='res2a_relu')
         .conv(1, 1, 64, 1, 1, biased=False, relu=False, name='res2b_branch2a')
         .batch_normalization(is_training=is_training, activation_fn=tf.nn.relu, name='bn2b_branch2a')
         .conv(3, 3, 64, 1, 1, biased=False, relu=False, name='res2b_branch2b')
         .batch_normalization(is_training=is_training, activation_fn=tf.nn.relu, name='bn2b_branch2b')
         .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='res2b_branch2c')
         .batch_normalization(is_training=is_training, activation_fn=None, name='bn2b_branch2c'))

    (self.feed('res2a_relu', 
               'bn2b_branch2c')
         .add(name='res2b')
         .relu(name='res2b_relu')
         .conv(1, 1, 64, 1, 1, biased=False, relu=False, name='res2c_branch2a')
         .batch_normalization(is_training=is_training, activation_fn=tf.nn.relu, name='bn2c_branch2a')
         .conv(3, 3, 64, 1, 1, biased=False, relu=False, name='res2c_branch2b')
         .batch_normalization(is_training=is_training, activation_fn=tf.nn.relu, name='bn2c_branch2b')
         .conv(1, 1, 256, 1, 1, biased=False, relu=False, name='res2c_branch2c')
         .batch_normalization(is_training=is_training, activation_fn=None, name='bn2c_branch2c'))

...

Может ли кто-нибудь помочь мне использовать входные данные динамически?

...