Ошибка Tensorflow py_function в пользовательской потере Keras при манипулировании формой ввода, объект NoneType не может быть интерпретирован как целое число - PullRequest
0 голосов
/ 17 февраля 2019

У меня проблема при создании пользовательской функции потерь в Keras для модели оценки позы.Функция потерь принимает вектор (36,0) в качестве метки истинности земли и другой вектор такой же формы, что и прогноз.Но функции Keras передают их в качестве тестеров.Я вычисляю потери по тепловым картам, а не по векторам, поэтому функция потерь преобразует вектор (36,0) в матрицу (224 224) с помощью некоторого сложного преобразования.

Моя проблема возникает, когда я пытаюсь манипулировать значениями вводавектор с использованием кода Python, я попытался py_function и после многих неудачных попыток, похоже, работает.Но он все равно не может пройти этап компиляции.это дает ошибку при среднеквадратичной ошибке функции Кераса, что существует некоторый тип объекта NoneType.И когда я попытался увидеть выходные данные функции py_function, я увидел, что она возвращает объект «EagerPyFunction».

Можете ли вы помочь мне найти ошибку и найти способ, как это сделать, не переопределяя функцию heatmap при использованииФункции тензорного потока?

def custom_loss_py__(y_true, y_pred):



 loss = 0

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    b= pair_idx(36,1)[:,0,np.newaxis]
    c= pair_idx(36,1)[:,1,np.newaxis]
    print("paring done")
    result = tf.gather_nd(y_true, tf.stack((b, c),1))
    result_ = tf.gather_nd(y_pred, tf.stack((b, c),1))
    print("gathering done")
    heatmap_true = tf.py_function(get_heatmap, [result, 224, 224], tf.float16)
    print(heatmap_true.op.type)
    heatmap_pred = tf.py_function(get_heatmap, [result_, 224, 224], tf.float16)
    print("sos")

  loss += keras.losses.mean_squared_error(heatmap_true, heatmap_pred)
  print("loss")
  return loss
model.compile(loss=custom_loss_py__, optimizer='adam')

выше приведен используемый код, а ниже вывод:

paring done
gathering done
EagerPyFunc
sos
loss

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-33-3f07b5ecc513> in <module>()
----> 1 model.compile(loss=custom_loss_py__, optimizer='adam')

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, **kwargs)
    340                 with K.name_scope(self.output_names[i] + '_loss'):
    341                     output_loss = weighted_loss(y_true, y_pred,
--> 342                                                 sample_weight, mask)
    343                 if len(self.outputs) > 1:
    344                     self.metrics_tensors.append(output_loss)

/usr/local/lib/python3.6/dist-packages/keras/engine/training_utils.py in weighted(y_true, y_pred, weights, mask)
    425             weight_ndim = K.ndim(weights)
    426             score_array = K.mean(score_array,
--> 427                                  axis=list(range(weight_ndim, ndim)))
    428             score_array *= weights
    429             score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))

TypeError: 'NoneType' object cannot be interpreted as an integer
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...