Я загружаю файл CSV и хочу настроить tf.data.TextLineDataset для реализации функции карты в зависимости от атрибута самого класса.Например, я хочу удалить минимальное значение ранее нарисованных элементов.Единственный способ, который я нашел, это что-то вроде
class myDataSet(tf.data.TextLineDataset):
__slots__ = ['dim', 'batch_size', 'shuffle', 'tmin', 'finaldataset']
def __init__(self, batch_size=5, dim=14,
shuffle=True, filepath="test.csv"):
super(myDataSet, self).__init__([filepath])
self.dim = dim
self.batch_size = batch_size
self.shuffle = shuffle
self.tmin = tf.Variable(initial_value=[10.])
self.finaldataset = self.skip(1)
self.finaldataset = self.finaldataset.repeat()
self.finaldataset = self.finaldataset.shuffle(10 * self.batch_size)
self.finaldataset = self.finaldataset.prefetch(20 * self.batch_size)
self.finaldataset = self.finaldataset.batch(self.batch_size)
self.finaldataset = self.finaldataset.map(self.parsefn, num_parallel_calls=8)
def parsefn(self, tf_string):
data = tf.decode_csv(tf_string, record_defaults=[[0.]] * self.dim)
input_tens = data[:-1]
target_tens = data[-1]
input_tens = tf.stack(input_tens , axis=-1)
self.tmin = 0.8 * self.tmin + 0.2 * tf.reduce_min(target_tens, axis=0)
target_tens = target_tens - self.tmin
return input_tens, target_tens
def getminmax(self):
return self.tmin
. Этот код работает, даже если он не очень элегантен для меня, потому что у меня есть finaldataset внутри myDataSet.Однако я не могу получить доступ к значению tmin из основного кода.На самом деле,
sess.run(datasetG.getminmax())
выдает следующую ошибку:
Fetch argument <tf.Tensor 'add:0' shape=(2,) dtype=float32> cannot be interpreted as a Tensor.
Как я могу прочитать этот атрибут?Большое спасибо
- РЕДАКТИРОВАТЬ -
Этот код выполняется перед вызовом datasetG.getminmax ()
with tf.device('/cpu:0'):
datasetG = datamanagement.myDataSet(batch_size=5, dim=2,
shuffle=True, filepath="AB_fit.csv")
iterG = datasetG.finaldataset.make_initializable_iterator()
elG = iterG.get_next()
with tf.Session() as sess:
sess.run(tf.initializers.global_variables())
sess.run(iterG.initializer)
print(sess.run(elG))
sess.run(datasetG.getminmax())