Как расширить функциональность класса tf.data.TextLineDataset? - PullRequest
0 голосов
/ 05 февраля 2019

Я загружаю файл CSV и хочу настроить tf.data.TextLineDataset для реализации функции карты в зависимости от атрибута самого класса.Например, я хочу удалить минимальное значение ранее нарисованных элементов.Единственный способ, который я нашел, это что-то вроде

class myDataSet(tf.data.TextLineDataset):
__slots__ = ['dim', 'batch_size', 'shuffle', 'tmin', 'finaldataset']

def __init__(self, batch_size=5, dim=14,
             shuffle=True, filepath="test.csv"):
    super(myDataSet, self).__init__([filepath])

    self.dim = dim
    self.batch_size = batch_size
    self.shuffle = shuffle

    self.tmin = tf.Variable(initial_value=[10.])

    self.finaldataset = self.skip(1)

    self.finaldataset = self.finaldataset.repeat()
    self.finaldataset = self.finaldataset.shuffle(10 * self.batch_size)
    self.finaldataset = self.finaldataset.prefetch(20 * self.batch_size)

    self.finaldataset = self.finaldataset.batch(self.batch_size) 
    self.finaldataset = self.finaldataset.map(self.parsefn, num_parallel_calls=8)

def parsefn(self, tf_string):
    data = tf.decode_csv(tf_string, record_defaults=[[0.]] * self.dim)

    input_tens = data[:-1]
    target_tens = data[-1]

    input_tens = tf.stack(input_tens , axis=-1)

    self.tmin = 0.8 * self.tmin + 0.2 * tf.reduce_min(target_tens, axis=0)

    target_tens = target_tens - self.tmin

    return input_tens, target_tens

def getminmax(self):
    return self.tmin

. Этот код работает, даже если он не очень элегантен для меня, потому что у меня есть finaldataset внутри myDataSet.Однако я не могу получить доступ к значению tmin из основного кода.На самом деле,

sess.run(datasetG.getminmax())

выдает следующую ошибку:

Fetch argument <tf.Tensor 'add:0' shape=(2,) dtype=float32> cannot be interpreted as a Tensor.

Как я могу прочитать этот атрибут?Большое спасибо

- РЕДАКТИРОВАТЬ -

Этот код выполняется перед вызовом datasetG.getminmax ()

with tf.device('/cpu:0'):
datasetG = datamanagement.myDataSet(batch_size=5, dim=2,
             shuffle=True, filepath="AB_fit.csv")

iterG = datasetG.finaldataset.make_initializable_iterator()

elG = iterG.get_next()

with tf.Session() as sess:
    sess.run(tf.initializers.global_variables())
    sess.run(iterG.initializer)
    print(sess.run(elG))
    sess.run(datasetG.getminmax())
...