Назначить тензор в SessionRunHook - PullRequest
0 голосов
/ 25 февраля 2020

Я пытаюсь изменить веса тензора DeepSpeech2 модели , которая использует tf.estimator.Estimator. Я написал SessionRunHook после этого поста , но я получаю сообщение об ошибке при выполнении операции назначения. Я передаю хук в EstimatorSpe c и при вызове предиката (). Как я могу это исправить?

SessionRunHook:

class WeightPruningHook(tf.train.SessionRunHook):
    def __init__(self, name, minVal, maxVal):
        self.name = name
        self.epsilon = epsilon
        self.minVal = minVal
        self.maxVal = maxVal
        self.weights = tf.Variable(0.0)
        self.prune_op = tf.assign(self.weights, tf.clip_by_value(self.weights, self.minVal, self.maxVal))

    def before_run(self, run_context):
      self.weights = tf.get_default_graph().get_tensor_by_name(self.name+":0")
      run_context.session.run(self.prune_op)

Ошибка:

ValueError: Fetch argument <tf.Tensor 'Assign:0' shape=() dtype=float32_ref> cannot be interpreted as a Tensor. (Tensor Tensor("Assign:0", shape=(), dtype=float32_ref) is not an element of this graph.)
...