Я пытаюсь изменить веса тензора DeepSpeech2 модели , которая использует tf.estimator.Estimator. Я написал SessionRunHook после этого поста , но я получаю сообщение об ошибке при выполнении операции назначения. Я передаю хук в EstimatorSpe c и при вызове предиката (). Как я могу это исправить?
SessionRunHook:
class WeightPruningHook(tf.train.SessionRunHook):
def __init__(self, name, minVal, maxVal):
self.name = name
self.epsilon = epsilon
self.minVal = minVal
self.maxVal = maxVal
self.weights = tf.Variable(0.0)
self.prune_op = tf.assign(self.weights, tf.clip_by_value(self.weights, self.minVal, self.maxVal))
def before_run(self, run_context):
self.weights = tf.get_default_graph().get_tensor_by_name(self.name+":0")
run_context.session.run(self.prune_op)
Ошибка:
ValueError: Fetch argument <tf.Tensor 'Assign:0' shape=() dtype=float32_ref> cannot be interpreted as a Tensor. (Tensor Tensor("Assign:0", shape=(), dtype=float32_ref) is not an element of this graph.)