В моей модели (используйте Tensorflow Estimator) я хочу, чтобы подача данных была более динамичной. Например, предоставить разные данные во время обучения (на разных этапах обучения предоставляются разные данные для модели).
Один пример, подобный следующим кодам. get_input_fn предоставляет функции input_fn и _parse для обработки функций. _py_process_line_pair внутри _parse выполняет точную обработку. Но я не уверен, как передать global_step (или связанный параметр в _py_process_line_pair)
def _parse(self, features):
def _py_process_line_pair(src_wds, trg_wds, cur_training_steps):
.... (some processing depends on cur_training_steps)
return np.array(src_ids, np.int32), np.array(trg_ids, np.int32)
src_wds, trg_wds = features['src_wds'], features['trg_wds']
src_ids, trg_ids = tf.py_func(
_py_process_line_pair,
[src_wds, trg_wds],
[tf.int32, tf.int32])
src_ids.set_shape(
[self.flags.max_src_len])
trg_ids.set_shape(
[self.flags.max_trg_len])
output = {
'src_ids': src_ids,
'trg_ids': trg_ids,
}
return output
def get_input_fn(self, is_training, input_files, num_cpu_threads):
def input_fn(params):
batch_size = params['batch_size']
if is_training:
d = tf.data.Dataset.from_tensor_slices(tf.constant(tf.gfile.Glob(input_files)))
d = d.repeat()
d = d.shuffle(buffer_size=len(input_files))
cycle_length = min(num_cpu_threads, len(input_files))
d = d.apply(
tf.data.experimental.parallel_interleave(
tf.data.TFRecordDataset,
sloppy=is_training,
cycle_length=cycle_length))
d = d.shuffle(buffer_size=100)
else:
d = tf.data.TFRecordDataset(input_files)
d = d.apply(
tf.data.experimental.map_and_batch(
lambda record: self._parse(tf.parse_single_example(record, self.feature_set)),
batch_size=batch_size,
num_parallel_batches=num_cpu_threads,
drop_remainder=is_training))
return d
return input_fn