Как пропустить ошибочные элементы на уровне io в apache beam с Dataflow? - PullRequest
1 голос
/ 25 февраля 2020

Я делаю некоторый анализ tfrecords, хранящихся в GCP, но некоторые из tfrecords внутри файлов повреждены, поэтому, когда я запускаю свой конвейер и получаю более четырех ошибок, мой конвейер разрывается из-за this, Я думаю, что это ограничение DataFlowRunner, а не луча.

Вот мой сценарий обработки

import argparse
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.metrics.metric import Metrics

from apache_beam.runners.direct import direct_runner
import tensorflow as tf

input_ = "path_to_bucket"


def _parse_example(serialized_example):
  """Return inputs and targets Tensors from a serialized tf.Example."""
  data_fields = {
      "inputs": tf.io.VarLenFeature(tf.int64),
      "targets": tf.io.VarLenFeature(tf.int64)
  }
  parsed = tf.io.parse_single_example(serialized_example, data_fields)
  inputs = tf.sparse.to_dense(parsed["inputs"])
  targets = tf.sparse.to_dense(parsed["targets"])
  return inputs, targets


class MyFnDo(beam.DoFn):

  def __init__(self):
    beam.DoFn.__init__(self)
    self.input_tokens = Metrics.distribution(self.__class__, 'input_tokens')
    self.output_tokens = Metrics.distribution(self.__class__, 'output_tokens')
    self.num_examples = Metrics.counter(self.__class__, 'num_examples')
    self.decode_errors = Metrics.counter(self.__class__, 'decode_errors')

  def process(self, element):
    # inputs = element.features.feature['inputs'].int64_list.value
    # outputs = element.features.feature['outputs'].int64_list.value
    try:
      inputs, outputs = _parse_example(element)
      self.input_tokens.update(len(inputs))
      self.output_tokens.update(len(outputs))
      self.num_examples.inc()
    except Exception:
      self.decode_errors.inc()



def main(argv):
  parser = argparse.ArgumentParser()
  parser.add_argument('--input', dest='input', default=input_, help='input tfrecords')
  # parser.add_argument('--output', dest='output', default='gs://', help='output file')

  known_args, pipeline_args = parser.parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)

  with beam.Pipeline(options=pipeline_options) as p:
    tfrecords = p | "Read TFRecords" >> beam.io.ReadFromTFRecord(known_args.input,
                                                                 coder=beam.coders.ProtoCoder(tf.train.Example))
    tfrecords | "count mean" >> beam.ParDo(MyFnDo())


if __name__ == '__main__':
    main(None)

, так что в основном, как я могу пропустить поврежденные tfrecords и записать их номера во время моего анализа

1 Ответ

0 голосов
/ 26 февраля 2020

С этим возникла концептуальная проблема: beam.io.ReadFromTFRecord считывает данные из одной tfrecords (которые могли быть переданы в несколько файлов), тогда как я давал список многих отдельных tfrecords, и, следовательно, это вызывало ошибку. Переход на ReadAllFromTFRecord с ReadFromTFRecord решил мою проблему.

p = beam.Pipeline(runner=direct_runner.DirectRunner())
tfrecords = p | beam.Create(tf.io.gfile.glob(input_)) | ReadAllFromTFRecord(coder=beam.coders.ProtoCoder(tf.train.Example))
tfrecords | beam.ParDo(MyFnDo())
...