TypeError: аргумент Fetch None имеет недопустимый тип . Проблема тензорного потока - PullRequest
0 голосов
/ 05 августа 2020

Я пытаюсь обучить cente rnet с помощью чистой реализации тензорного потока, но при попытке запустить файл python возникает проблема, которую я не могу решить. Это репо с исходным кодом: https://github.com/MioChiu/TF_CenterNet

Это код:

        sess.run(tf.compat.v1.global_variables_initializer())
        if cfg.pre_train:
            load_weights(sess,'./pretrained_weights/Resnet50.npy')
        for epoch in range(1, 1+cfg.epochs):
            pbar = tqdm(range(num_train_batch))
            train_epoch_loss, test_epoch_loss = [], []
            sess.run(trainset_init_op)
            for i in pbar:
                _, summary, train_step_loss, global_step_val = sess.run(
                    [train_op, write_op, total_loss, global_step],feed_dict={is_training:True})

                train_epoch_loss.append(train_step_loss)
                summary_writer.add_summary(summary, global_step_val)
                pbar.set_description("train loss: %.2f" %train_step_loss)

            sess.run(testset_init_op)
            for j in range(num_test_batch):
                test_step_loss = sess.run( total_loss, feed_dict={is_training:False})
                test_epoch_loss.append(test_step_loss)

            train_epoch_loss, test_epoch_loss = np.mean(train_epoch_loss), np.mean(test_epoch_loss)
            ckpt_file = "./checkpoint/centernet_test_loss=%.4f.ckpt" % test_epoch_loss
            log_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
            print("=> Epoch: %2d Time: %s Train loss: %.2f Test loss: %.2f Saving %s ..."
                            %(epoch, log_time, train_epoch_loss, test_epoch_loss, ckpt_file))
            saver.save(sess, ckpt_file, global_step=epoch)


if __name__ == '__main__': train() 

Это сообщение об ошибке:

Traceback (most recent call last):
  File "train.py", line 142, in <module>
    if __name__ == '__main__': train()
  File "train.py", line 123, in train
    [train_op, write_op, total_loss, global_step],feed_dict={is_training:True})
  File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 958, in run
    run_metadata_ptr)
  File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1166, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 477, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 266, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 378, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 378, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 263, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>

Помогите, пожалуйста!

...