Оценка тензора API набора данных TensorFlow в функции карты - PullRequest
0 голосов
/ 29 августа 2018

У меня есть следующая функция ввода набора данных для создания генератора набора данных.

def dataset_input_fn(filenames, shuffle, batch_size, sample):
    def parser(record):
        features = {
            'mean_rgb': tf.FixedLenFeature([1024], tf.float32),
            'category': tf.FixedLenFeature([], tf.int64)
        }
        parsed = tf.parse_single_example(record, features)

        vrv = parsed['mean_rgb']
        label = tf.cast(parsed['category'], tf.int32)
        return {"mean_rgb": vrv}, label

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parser)
    if sample:
        dataset = dataset.flat_map(
            lambda x, y: tf.data.Dataset.from_tensors((x, y)).repeat(oversample_classes(y))
        )
        dataset = dataset.filter(undersampling_filter)
    dataset = dataset.shuffle(buffer_size=100 * batch_size)
    dataset = dataset.batch(batch_size).repeat(1)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

Я пытаюсь следовать этому коду для перевыбора / выборки данных на основе метки. В моей функции dataset.flat_map я перебираю каждую метку и хотел бы определить, как часто ее повторять. Тем не менее, у является тензором, и я не могу оценить его как целое число. Когда я пытаюсь sess.run(label), я получаю

ValueError: аргумент Fetch не может быть истолковано как Тензор. (Тензор Тензор ("arg1: 0", shape = (), dtype = int32) не является элементом этого графа.)

...