вывод из замороженной модели в тензорном потоке с использованием API набора данных - PullRequest
0 голосов
/ 21 сентября 2018

Я пытаюсь сделать вывод из замороженного графика (созданного с помощью freeze_graph.py) тензорной и глубокой модели.Я использую Dataset API для анализа файла test.csv.Поскольку feed dict принимает только массивы numpy, а содержимое набора данных является тензором, я получаю batch next_element = iterator.get_next () и batch = sess.run (next_element), чтобы получить значения numpy, а затем передать их в заполнители, используя dict фида.Но это не дает мне хорошей пропускной способности при работе с большими наборами данных, так как это не эффективный способ преобразовать тензоры в массивы с нуля, а затем передать их в заполнители.Есть ли эффективный способ сделать это.

def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run census_dataset.py and '
      'set the --data_dir argument to the correct path.' % data_file)

  def parse_csv(value):
    tf.logging.info('Parsing {}'.format(data_file))
    cont_defaults = [ [0.0] for i in range(1,14) ]
    cate_defaults = [ [" "] for i in range(1,27) ]
    label_defaults = [ [0] ]
    column_headers = TRAIN_DATA_COLUMNS
    record_defaults = label_defaults + cont_defaults + cate_defaults
    columns = tf.decode_csv(value, record_defaults=record_defaults)
    all_columns = dict(zip(column_headers, columns))
    labels = all_columns.pop(LABEL_COLUMN[0])
    features = all_columns
    return features, labels

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=2000)

  dataset = dataset.map(parse_csv, num_parallel_calls=8)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(56)
  return dataset 



with tf.Session(graph=graph) as sess:
  res_dataset = input_fn(predictioninputfile,1,False,batch_size)
  iterator = res_dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  inference_start = time.time()
  for i in range(no_of_batches):
    batch=sess.run(next_element)
    features,actual_label=batch[0],batch[1]
    #print("features",features)
    logistic = sess.run(output_tensor, dict(zip(input_tensor,list(features.values()))))
...