A tf.estimator
input_fn
подпись может выглядеть примерно так:
def input_fn(files:list, params:dict):
dataset = tf.data.TFRecordDataset(files)
.map(lambda record: parse_record_fn(record))
if params['mode'] == 'train':
# train specific things
# ...
Такое определение позволяет одному затем построить все свои input_fn
s следующим образом:
train_fn = lambda: input_fn(files['training_set'], {**params, **{"mode": "train"}})
valid_fn = lambda: input_fn(files['validation_set'], {**params, **{"mode": "eval"}})
test_fn = lambda: input_fn(files['test_set'], {**params, **{"mode": "test"}})
train_spec = tf.estimator.TrainSpec(input_fn=train_fn, ...)
eval_spec = tf.estimator.EvalSpec(input_fn=valid_fn, ...)
Мой вопрос заключается в том, как изменить сигнатуру input_fn
, чтобы учесть вариации на основе эпох.Я понимаю, что это может создать бутылочное горлышко, но было бы неплохо, если бы я мог сделать что-то вроде:
def input_fn(...):
# see above
epoch = params["epoch"]
if epoch % 100 == 0:
# modify or make a new dataset
# ...
return dataset.make_one_shot_iterator().get_next()
Ключ в том, чтобы убедиться, что input_fn
все еще совместим с:
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)