Использование набора данных Tensorflow CSV - PullRequest
0 голосов
/ 08 сентября 2018

У меня есть CSV-файл со следующим форматом и данными:

ID  nr1 nr2 nr3 nr4 nr5 next_nr
1   1   2   3   4   5   6
2   2   3   4   5   6   7
3   3   4   5   6   7   8
4   4   5   6   7   8   9
5   5   6   7   8   9   10
6   6   7   8   9   10  11
7   7   8   9   10  11  12
8   8   9   10  11  12  13
9   9   10  11  12  13  14
10  10  11  12  13  14  15

Итак, есть 10 строк с данными о моем поезде. И я хочу использовать tf.contrib.data.CsvDataset для чтения данных. Вот пример кода для его чтения:

import tensorflow as tf
import numpy as np

ITERATOR_BATCH_SIZE = 2
NR_EPOCHS = 3

train1_path = 'train1_short.csv'

dataset = tf.contrib.data.CsvDataset(train1_path,
                                     [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32],
                                     header=True)

dataset = dataset.batch(ITERATOR_BATCH_SIZE)

with tf.Session() as sess:

    for i in range (NR_EPOCHS):
        print('\nepoch: ', i)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        while True:            
            try:
              data_and_target = sess.run([next_element])
            except tf.errors.OutOfRangeError:
              break
            print("\n\n", data_and_target)

Когда я запускаю этот код, я ожидаю, что вывод будет включать 2 строки данных в каждом пакете. Но данные, которые я получаю, выглядят довольно странно. Вот вывод из первой эпохи:

epoch:  0


 [(array([1., 2.], dtype=float32), array([1., 2.], dtype=float32), array([2., 3.], dtype=float32), array([3., 4.], dtype=float32), array([4., 5.], dtype=float32), array([5., 6.], dtype=float32), array([6., 7.], dtype=float32))]


 [(array([3., 4.], dtype=float32), array([3., 4.], dtype=float32), array([4., 5.], dtype=float32), array([5., 6.], dtype=float32), array([6., 7.], dtype=float32), array([7., 8.], dtype=float32), array([8., 9.], dtype=float32))]


 [(array([5., 6.], dtype=float32), array([5., 6.], dtype=float32), array([6., 7.], dtype=float32), array([7., 8.], dtype=float32), array([8., 9.], dtype=float32), array([ 9., 10.], dtype=float32), array([10., 11.], dtype=float32))]


 [(array([7., 8.], dtype=float32), array([7., 8.], dtype=float32), array([8., 9.], dtype=float32), array([ 9., 10.], dtype=float32), array([10., 11.], dtype=float32), array([11., 12.], dtype=float32), array([12., 13.], dtype=float32))]


 [(array([ 9., 10.], dtype=float32), array([ 9., 10.], dtype=float32), array([10., 11.], dtype=float32), array([11., 12.], dtype=float32), array([12., 13.], dtype=float32), array([13., 14.], dtype=float32), array([14., 15.], dtype=float32))]

Вместо этого, я бы, например, ожидал, что самой первой партии будет нравиться следующее:

[(array([1., 1., 2., 3., 4., 5., 6], dtype=float32), array([2., 2., 3., 4., 5., 6., 7.], dtype=float32)]

Проблема может быть очень тривиальной, но я просто не понимаю, почему это выглядит так. Может быть, более опытный человек в этой области сразу же это увидит.

1 Ответ

0 голосов
/ 09 сентября 2018

Каждая запись CsvDatset должна быть преобразована в тензор. Дайте мне знать, если это работает для вас:

dataset = tf.contrib.data.CsvDataset(train1_path,
                                     [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32],
                                     header=True, field_delim=' ')

dataset = dataset.map(lambda *x: tf.convert_to_tensor(x))
dataset = dataset.batch(ITERATOR_BATCH_SIZE)

with tf.Session() as sess:
    for i in range (NR_EPOCHS):
        print('\nepoch: ', i)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        while True:            
            try:
              data_and_target = sess.run(next_element)
            except tf.errors.OutOfRangeError:
              break
            print("\n\n", data_and_target)

Для моего теста мне пришлось установить аргумент field_delim, чтобы он заработал.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...