У меня есть пользовательский оценщик с tf.estimator.BestExporter
в виде:
exporter = tf.estimator.BestExporter(
name="best_exporter",
serving_input_receiver_fn=serving_input_receiver_fn,
exports_to_keep=5
) # this will keep the 5 best checkpoints
, поэтому при model_dir
у меня теперь есть:
# (inside model_dir/)
...
export/
- best_exporter/
- <timestamp>
- variables/
- variables.data-00000-of-00001
- variables.index
- saved_model.pb
Я могу загрузить и использоватьмой экспортированный оценщик через
predict_fn = predictor.from_saved_model(os.path.join(best_exporter_dir, timestamp))
Я хотел бы иметь возможность обновить значения этого оценщика (например, веса некоторого слоя some_layer/kernel:0
)
Есть связанный (но не идентичный) GitHub выпуск , в котором вместо этого рассматривается, как можно это сделать, с помощью модели контрольные точки ( соответствующая часть выпуска ), которую автор подтвердил для работы с TensorFlow v1.4.
Я пытался соткать соответствующие части этого кода, чтобы хотя бы иметь возможность обновить некоторые веса:
def load_estimator_graph(export_dir:str)->None:
'''Solves import issues when using tf.estimator.(Best)Exporter for saving
models rather than using the last checkpoint.
Arguments:
export_dir (str): the full path to exported tf.estimator model
Returns:
None
'''
with tf.Session(graph=tf.Graph()) as sess:
meta_graph = tf.saved_model.loader.load(sess, ['serve'], export_dir)
with tf.Session() as sess:
loaded_graph = tf.train.import_meta_graph(meta_graph)
def lazy_fetch_variable_values(variable_names:list)->dict:
'''
Notes:
"lazy" refers to:
1. the use of `tf.initialize_all_variables()` to ensure
variables have values
2. the use of `tf.trainable_variables()` to search the likely
releveant values
Arguments:
variable_names (list): list of variable names (str) to retrieve from the
default tensorflow graph
Returns:
variables (dict): key:value of the variables and the values as pythonic
data types.
'''
init_op = tf.initialize_all_variables()
variables = {}
with tf.Session() as sess:
sess.run(init_op)
tvars = tf.trainable_variables()
tvars_vals = sess.run(tvars)
for var, val in zip(tvars, tvars_vals):
if var.name in variable_names:
variables[var.name] = val
return variables
def lazy_set_variable_values(variables_to_set:dict):
'''
Arguments:
variables_to_set (dict): variable_name, variable_value pairs for which
to be updated in the graph
'''
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init_op)
tf_global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for var_to_find, val_to_set in variables_to_set.items():
var = [v for v in tf_global_vars if v.name == var_to_find][0]
sess.run(var)
var = var.assign(val_to_set)
sess.run(var)
, а затем что-то вроде:
load_estimator_graph(best_exported_model_dir)
layer_name = 'some_layer/kernel:0'
weights = lazy_fetch_variable_values([layer_name])[layer_name]
new_weights = np.copy(weights)
new_weights = 0 # <-- np.ndarray, this sets _element-wise_ all values to 0,
# has same shape as original weight tensor
lazy_set_varriable_values({layer_name: new_weights})
with tf.Session() as session:
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
saver.save(sess, os.path.join(best_exported_model_dir, '..', 'best_updated'))
Следует отметить, что это чтение в модели tf.estimator.BestExporter
и попытка экспорта в контрольную точку.
Поэтому, если я пытаюсь восстановить контрольную точку:
est = tf.estimator.Estimator(
model_fn = model_fn,
model_dir = os.path.join(best_exported_model_dir, '..', 'best_updated'),
config = tf.estimator.RunConfig(**_config['RunConfig']), # same as runtime call
params = _config, # same as runtime call
)
eval_fn = lambda : input_fn(mode='eval')
est.evaluate(eval_fn)
, я получаю:
ValueError Input 0 of layer some_layer is incompatible with the layer: : expected min_ndim=2, found ndim=1. Full shape received: [an_integer]
где в приведенном выше коде
weights.shape[0] == new_weights.shape[0] == an_integer
суть проблемы
В идеале я предпочел бы сохранить обновленныймодель в той же форме, что и tf.estimator.BestExporter
и tf.estimator.Estimator.export_savedmodel
.
Однако для перечисленных выше методов экспорта требуется экземпляр estimator
и соответствующий serving_input_receiver_fn
.Метод predictor.from_saved_model(exported_dir)
не инициализирует оценщик!так что, кажется, нет простого способа сделать это.
Примечания: - predictor
взято из from tensorflow.contrib import predictor
- Я хотел бы импортировать из экспортированной модели, обновить некоторые значения (например, смещение / веса)), а затем экспортируйте в той же форме (не перезаписывайте исходную модель).