У меня есть предварительно обученная сеть, которую я хочу использовать для оценки потерь в моей сети Keras.Предварительно обученная сеть была обучена с использованием TensorFlow, и я просто хочу использовать ее как часть моего расчета потерь.
Код моей пользовательской функции потерь в настоящий момент:
def custom_loss_func(y_true, y_pred):
# Get saliency of both true and pred
sal_true = deep_gaze.get_saliency_map(y_true)
sal_pred = deep_gaze.get_saliency_map(y_pred)
return K.mean(K.square(sal_true-sal_pred))
Где deep_gaze - это объект, предназначенный для управления доступом к внешней предварительно обученной сети, которую я использую.
Это определяется следующим образом:
class DeepGaze(object):
CHECK_POINT = os.path.join(os.path.dirname(__file__), 'DeepGazeII.ckpt') # DeepGaze II
def __init__(self):
print('Loading Deep Gaze II...')
with tf.Graph().as_default() as deep_gaze_graph:
saver = tf.train.import_meta_graph('{}.meta'.format(self.CHECK_POINT))
self.input_tensor = tf.get_collection('input_tensor')[0]
self.log_density_wo_centerbias = tf.get_collection('log_density_wo_centerbias')[0]
self.tf_session = tf.Session(graph=deep_gaze_graph)
saver.restore(self.tf_session, self.CHECK_POINT)
print('Deep Gaze II Loaded')
'''
Returns the saliency map of the input data.
input format is a 4d array [batch_num, height, width, channel]
'''
def get_saliency_map(self, input_data):
log_density_prediction = self.tf_session.run(self.log_density_wo_centerbias,
{self.input_tensor: input_data})
return log_density_prediction
Когда я запускаю это, я получаю сообщение об ошибке:
TypeError: Значение фида не может быть tf.Tensorобъект.Приемлемые значения подачи включают скаляры Python, строки, списки, numy ndarrays или TensorHandles.
Что я делаю неправильно?Есть ли способ оценить сеть на объекте TensorFlow для другой сети (созданной Keras с бэкэндом TensorFlow).
Заранее спасибо.