Я пытаюсь обучить CenterNet с помощью реализации на чистом TensorFlow, но при запуске Python-скрипта (train.py) возникает ошибка, которую я не могу устранить. Репозиторий с исходным кодом: https://github.com/MioChiu/TF_CenterNet
Код (фрагмент функции train() из train.py):
sess.run(tf.compat.v1.global_variables_initializer())
if cfg.pre_train:
load_weights(sess,'./pretrained_weights/Resnet50.npy')
for epoch in range(1, 1+cfg.epochs):
pbar = tqdm(range(num_train_batch))
train_epoch_loss, test_epoch_loss = [], []
sess.run(trainset_init_op)
for i in pbar:
_, summary, train_step_loss, global_step_val = sess.run(
[train_op, write_op, total_loss, global_step],feed_dict={is_training:True})
train_epoch_loss.append(train_step_loss)
summary_writer.add_summary(summary, global_step_val)
pbar.set_description("train loss: %.2f" %train_step_loss)
sess.run(testset_init_op)
for j in range(num_test_batch):
test_step_loss = sess.run( total_loss, feed_dict={is_training:False})
test_epoch_loss.append(test_step_loss)
train_epoch_loss, test_epoch_loss = np.mean(train_epoch_loss), np.mean(test_epoch_loss)
ckpt_file = "./checkpoint/centernet_test_loss=%.4f.ckpt" % test_epoch_loss
log_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
print("=> Epoch: %2d Time: %s Train loss: %.2f Test loss: %.2f Saving %s ..."
%(epoch, log_time, train_epoch_loss, test_epoch_loss, ckpt_file))
saver.save(sess, ckpt_file, global_step=epoch)
# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    train()
Сообщение об ошибке:
Traceback (most recent call last):
File "train.py", line 142, in <module>
if __name__ == '__main__': train()
File "train.py", line 123, in train
[train_op, write_op, total_loss, global_step],feed_dict={is_training:True})
File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 958, in run
run_metadata_ptr)
File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1166, in _run
self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 477, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 266, in for_fetch
return _ListFetchMapper(fetch)
File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 378, in __init__
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 378, in <listcomp>
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "/home/andrea/virtualenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 263, in for_fetch
(fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>
Помогите, пожалуйста!