Я работаю над оценкой Tensorflow с использованием RNN (GRUCell).
Я использую zero_state для инициализации первого состояния, оно требует фиксированного размера.
Моя проблема в том, что я хочу иметь возможность использовать оценщик для прогнозирования по одной выборке (batchsize = 1).
Когда он загружает сериализованный оценщик, он жалуется, что размер партии, которую я использую для прогнозирования, не соответствует размеру обучающей партии.
Если я реконструирую оценщик с другим размером пакета, я не могу загрузить то, что было сериализовано.
Есть ли элегантный способ использования zero_state в оценщике?
Я видел некоторые решения, использующие переменную для хранения размера пакета, но с использованием метода feed_dict. Я не вижу, как заставить это работать в контексте оценки.
Вот суть моего простого теста RNN в оценке:
cells = [ tf.nn.rnn_cell.GRUCell(self.getNSize()) for _ in range(self.getNLayers())]
multicell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=False)
H_init = tf.Variable( multicell.zero_state( batchsize, dtype=tf.float32 ), trainable=False)
H = tf.Variable( H_init )
Yr, state = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=H)
Может быть, кто-то знает об этом?
EDIT:
Хорошо, я пробую разные вещи по этой проблеме.
Теперь я пытаюсь отфильтровать переменные, которые я загружаю из контрольной точки, чтобы удалить 'H', который используется как внутреннее состояние рекуррентных ячеек. Для прогноза я могу оставить все 0 значений.
Пока что я это сделал:
Сначала я определяю крюк:
class RestoreHook(tf.train.SessionRunHook):
def __init__(self, init_fn):
self.init_fn = init_fn
def after_create_session(self, session, coord=None):
print("--------------->After create session.")
self.init_fn(session)
Тогда в моей model_fn:
if mode == tf.estimator.ModeKeys.PREDICT:
logits = tf.nn.softmax(logits)
# Do not restore H as it's batch size might be different.
vlist = tf.contrib.framework.get_variables_to_restore()
vlist = [ x for x in vlist if x.name.split(':')[0] != 'architecture/H']
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(tf.train.latest_checkpoint(self.modelDir), vlist, ignore_missing_vars=True)
spec = tf.estimator.EstimatorSpec(mode=mode,
predictions = {
'logits': logits,
},
export_outputs={
'prediction': tf.estimator.export.PredictOutput( logits )
},
prediction_hooks=[RestoreHook(init_fn)])
Я взял этот кусок кода у https://github.com/tensorflow/tensorflow/issues/14713
Но это еще не работает. Кажется, он все еще пытается загрузить H из файла ... Я проверил, что его нет во vlist.
Я все еще ищу решение.