Tfrecord получить непревзойденные парные данные - PullRequest
0 голосов
/ 13 марта 2019

У меня есть много совпадающих парных изображений из 2 доменов A и B. Я сохранил эти изображения в некоторые файлы tfrecord, но при загрузке парных данных из файлов они больше не совпадают.

Здесьмой код сохранения:

def save_tfrecords(paths, desdir):
    cnt_file_idx = 0
    cnt_data_idx = 0

    filename = os.path.join(desdir, 'data%d.tfrecords' % cnt_file_idx)
    filename_list = [filename]

    writer = tf.python_io.TFRecordWriter(filename)
    for i, path in enumerate(paths):
        data = np.load(path)
        data_shape = np.shape(data)

        width = data_shape[1]  # [height, width, channels]
        a_image = np.array(data[:, :width // 2])
        b_image = np.array(data[:, width // 2:])

        # until here I have got correct image pairs

        features = tf.train.Features(
            feature={
                "A": tf.train.Feature(float_list=tf.train.FloatList(value=a_image.reshape(-1))),
                "B": tf.train.Feature(float_list=tf.train.FloatList(value=b_image.reshape(-1))),
                "a_shape": tf.train.Feature(int64_list=tf.train.Int64List(value=np.shape(a_image))),
                "b_shape": tf.train.Feature(int64_list=tf.train.Int64List(value=np.shape(b_image)))
            }
        )
        example = tf.train.Example(features=features)
        serialized = example.SerializeToString()
        writer.write(serialized)

        cnt_data_idx += 1
        if cnt_data_idx == 500:
            writer.close()
            cnt_file_idx += 1
            cnt_data_idx = 0
            filename = os.path.join(desdir, 'data%d.tfrecords' % cnt_file_idx)
            filename_list.append(filename)
            writer = tf.python_io.TFRecordWriter(filename)
    writer.close()
    return filename_list

и мой код загрузки:

def load_example(path):  # return 2 iterator (not initialized)
    def pares_tf(example_proto):
        features = {
            "A": tf.VarLenFeature(dtype=tf.float32),
            "B": tf.VarLenFeature(dtype=tf.float32),

            "a_shape": tf.FixedLenFeature(shape=(2,), dtype=tf.int64),
            "b_shape": tf.FixedLenFeature(shape=(2,), dtype=tf.int64)
        }

        parsed_example = tf.parse_single_example(serialized=example_proto, features=features)

        parsed_example['A'] = tf.sparse_tensor_to_dense(parsed_example['A'])
        parsed_example['B'] = tf.sparse_tensor_to_dense(parsed_example['B'])

        parsed_example['A'] = tf.reshape(parsed_example['A'], parsed_example['a_shape'])
        parsed_example['A'] = tf.expand_dims(parsed_example['A'], -1)
        parsed_example['B'] = tf.reshape(parsed_example['B'], parsed_example['b_shape'])
        parsed_example['B'] = tf.expand_dims(parsed_example['B'], -1)

        return parsed_example

    tf_data_dir = os.path.join(path, 'train', 'pair', 'tf_data')
    tf_filename_list = glob.glob(os.path.join(tf_data_dir, "*.tfrecords"))

    dataset_train = tf.data.TFRecordDataset(tf_filename_list)
    dataset_train = dataset_train.map(pares_tf).repeat().batch(32)

    iterator_train = dataset_train.make_initializable_iterator()

    return iterator_train

Другая сбивающая с толку вещь заключается в том, что когда я устанавливаю размер пакета, равный количеству данных, сохраненных в tfrecords(здесь 500) при загрузке данных кажется, что пары изображений правильные, но если я установлю размер пакета равным 499, несоответствие произойдет с фиксированным расстоянием 1, то есть i-й изображение из домена A в паре с (i + 1) -ым изображением из домена B , и если размер пакета равен 498, несоответствиенаходится на расстоянии 2 (i в A в паре с i + 2 в B) и т. д.

Я не понимаю, почему это может произойти.Кто-нибудь может мне помочь с этим вопросом?

...