Как использовать предварительно обученную сеть TensorFlow в функции потерь Keras - PullRequest
0 голосов
/ 12 мая 2018

У меня есть предварительно обученная сеть, которую я хочу использовать для оценки потерь в моей сети Keras.Предварительно обученная сеть была обучена с использованием TensorFlow, и я просто хочу использовать ее как часть моего расчета потерь.

Код моей пользовательской функции потерь в настоящий момент:

def custom_loss_func(y_true, y_pred):
   # Get saliency of both true and pred
   sal_true = deep_gaze.get_saliency_map(y_true)
   sal_pred = deep_gaze.get_saliency_map(y_pred)

   return K.mean(K.square(sal_true-sal_pred))

Где deep_gaze - это объект, предназначенный для управления доступом к внешней предварительно обученной сети, которую я использую.

Это определяется следующим образом:

class DeepGaze(object):
  CHECK_POINT = os.path.join(os.path.dirname(__file__), 'DeepGazeII.ckpt')  # DeepGaze II

def __init__(self):
    print('Loading Deep Gaze II...')

    with tf.Graph().as_default() as deep_gaze_graph:
        saver = tf.train.import_meta_graph('{}.meta'.format(self.CHECK_POINT))

        self.input_tensor = tf.get_collection('input_tensor')[0]
        self.log_density_wo_centerbias = tf.get_collection('log_density_wo_centerbias')[0]

    self.tf_session = tf.Session(graph=deep_gaze_graph)
    saver.restore(self.tf_session, self.CHECK_POINT)

    print('Deep Gaze II Loaded')

'''
Returns the saliency map of the input data. 
input format is a 4d array [batch_num, height, width, channel]
'''
def get_saliency_map(self, input_data):
    log_density_prediction = self.tf_session.run(self.log_density_wo_centerbias,
                                                 {self.input_tensor: input_data})

    return log_density_prediction

Когда я запускаю это, я получаю сообщение об ошибке:

TypeError: Значение фида не может быть tf.Tensorобъект.Приемлемые значения подачи включают скаляры Python, строки, списки, numy ndarrays или TensorHandles.

Что я делаю неправильно?Есть ли способ оценить сеть на объекте TensorFlow для другой сети (созданной Keras с бэкэндом TensorFlow).

Заранее спасибо.

1 Ответ

0 голосов
/ 13 мая 2018

Есть две основные проблемы:

  • Когда вы звоните get_saliency_map с input_data=y_true, вы передаете тензор input_data другому тензору self.input_tensor, и это не такдействительный.Кроме того, эти тензоры не содержат значения во время создания графа, а скорее определяют вычисления, которые в конечном итоге приведут к значению.

  • Даже если вы можете получить вывод из get_saliency_mapваш код все равно не будет работать, потому что эта функция отключает ваш граф TensorFlow (он не возвращает тензор), и вся логика должна находиться внутри графа.Каждый тензор должен быть рассчитан на основе других доступных тензоров на графике.

Решение этой проблемы состоит в определении модели, выдающей self.log_density_wo_centerbias в графе, где вы определяете свою потерюиспользовать тензоры y_true и y_pred непосредственно в качестве входных данных без отключения графика.

...