Tensorflow Estimator - манипулировать функцией затрат с весами - PullRequest
0 голосов
/ 21 января 2019

Я пытаюсь построить собственный оценщик, который будет делать следующее.На каждом шаге пакетной подпорки я хочу выполнить некоторые манипуляции со стоимостью, например:

loss = tf.squared_difference(y_pred, y)
weighted_loss = tf.multiply(weights, loss)
cost = tf.reduce_sum(weighted_loss) / batch_size

Матрица «весов» здесь представляет собой некоторые внешние данные (в основном она просто обнуляет некоторые элементы в каждомя не хочу поддерживать их), но это внешние данные, которые мне нужно будет предоставить функции model_fn для каждого пакета на этапе обучения.Как мне это сделать?Как я могу найти записи, которые являются частью текущей обучающей партии, и предоставить model_fn соответствующую матрицу весов для этих записей?

...