Я пытаюсь сделать вывод из замороженного графика (созданного с помощью freeze_graph.py) тензорной и глубокой модели.Я использую Dataset API для анализа файла test.csv.Поскольку feed dict принимает только массивы numpy, а содержимое набора данных является тензором, я получаю batch next_element = iterator.get_next () и batch = sess.run (next_element), чтобы получить значения numpy, а затем передать их в заполнители, используя dict фида.Но это не дает мне хорошей пропускной способности при работе с большими наборами данных, так как это не эффективный способ преобразовать тензоры в массивы с нуля, а затем передать их в заполнители.Есть ли эффективный способ сделать это.
def input_fn(data_file, num_epochs, shuffle, batch_size):
"""Generate an input function for the Estimator."""
assert tf.gfile.Exists(data_file), (
'%s not found. Please make sure you have run census_dataset.py and '
'set the --data_dir argument to the correct path.' % data_file)
def parse_csv(value):
tf.logging.info('Parsing {}'.format(data_file))
cont_defaults = [ [0.0] for i in range(1,14) ]
cate_defaults = [ [" "] for i in range(1,27) ]
label_defaults = [ [0] ]
column_headers = TRAIN_DATA_COLUMNS
record_defaults = label_defaults + cont_defaults + cate_defaults
columns = tf.decode_csv(value, record_defaults=record_defaults)
all_columns = dict(zip(column_headers, columns))
labels = all_columns.pop(LABEL_COLUMN[0])
features = all_columns
return features, labels
# Extract lines from input files using the Dataset API.
dataset = tf.data.TextLineDataset(data_file)
if shuffle:
dataset = dataset.shuffle(buffer_size=2000)
dataset = dataset.map(parse_csv, num_parallel_calls=8)
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(56)
return dataset
with tf.Session(graph=graph) as sess:
res_dataset = input_fn(predictioninputfile,1,False,batch_size)
iterator = res_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
inference_start = time.time()
for i in range(no_of_batches):
batch=sess.run(next_element)
features,actual_label=batch[0],batch[1]
#print("features",features)
logistic = sess.run(output_tensor, dict(zip(input_tensor,list(features.values()))))