Возврат определенных элементов с помощью API набора данных - PullRequest
0 голосов
/ 28 марта 2019

Я написал файл tfrecord, в котором у меня есть изображения и их ярлыки. Затем я могу подобрать их, используя

    def parserTrain(record):
        keys_to_features = {
            "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
            "label": tf.FixedLenFeature((), tf.int64,
                                        default_value=tf.zeros([], dtype=tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)

        # Perform additional preprocessing on the parsed data.
        image = tf.image.decode_jpeg(parsed["image_raw"])
        image = tf.reshape(image, [256, 256, 3])

        image = tf.transpose(image, perm=[2, 0, 1])  # channels first
        image = tf.truediv(image, 255.0)
        label = tf.cast(parsed["label"], tf.int32)

        return {"image": image}, label

    # Set up training input function.
    def train_input_fn():
        """Prepare data for training."""
        train_tfrecord = 'Dataset/train_images.tfrecords'

        dataset = tf.data.TFRecordDataset(train_tfrecord)
        dataset = dataset.map(parserTrain)

, после этого я хочу отфильтровать некоторые примеры, используя, вероятно, что-то вроде этого:

def f(x):
    return x[1] == 1


ds1 = dataset.filter(f)

но я получаю эту ошибку:

TypeError: f () принимает 1 позиционный аргумент, но 2 были даны

Ответы [ 2 ]

0 голосов
/ 01 апреля 2019

Отвечая на мой вопрос, когда я нашел ответ. правильный синтаксис функции фильтра для набора данных кортежей следующий:

def f(im, label):
    return tf.equal(label, 1)


ds1 = dataset.filter(f)
0 голосов
/ 28 марта 2019

Итак, если у вас есть набор данных (например, TFRecordDataset), вы можете отфильтровать примеры следующим образом:

  dataset = tf.data.TFRecordDataset(filenames=files)
  dataset = dataset.filter(lambda example: example["value"] == value and example["label"] == label)
  dataset = ...
...