Почему используется генератор данных, и в результате ошибка индекса 32 выходит за пределы оси 0 с размером 32 - PullRequest
0 голосов
/ 20 апреля 2019

Что случилось не так в моем коде. Спасибо за ответ на мой вопрос.

num_pairs = len(triples_data)
for i, (image_filename_l, image_filename_r, image_filename_n) in enumerate(triples_data):
    if i % 1000 == 0:
        print("images from {:d}/{:d} pairs loaded to cache".format(i, num_pairs))
    if image_filename_l not in image_cache:
        load_image_cache(image_cache, image_filename_l, IMAGE_DIR)
    if image_filename_r not in image_cache:
        load_image_cache(image_cache, image_filename_r, IMAGE_DIR)
    if image_filename_n not in image_cache:
        load_image_cache(image_cache, image_filename_n, IMAGE_DIR)
print("images from {:d}/{:d} pairs loaded to cache, COMPLETE".format(i, num_pairs))

# In[6]:
def pair_generator(triples, image_cache, datagens, batch_size=32):
    while True:
        # shuffle once per batch
        indices = np.random.permutation(np.arange(len(triples)))
        num_batches = len(triples) // batch_size
        for bid in range(num_batches):
            batch_indices = indices[bid * batch_size : (bid + 3) * batch_size]
            batch = [triples[i] for i in batch_indices]
            X1 = np.zeros((batch_size, 224, 224, 3))
            X2 = np.zeros((batch_size, 224, 224, 3))
            X3 = np.zeros((batch_size, 224, 224, 3))
            for i, (image_filename_l, image_filename_r, image_filename_n) in enumerate(batch):
                if datagens is None or len(datagens) == 0:
                    X1[i] = image_cache[image_filename_l]
                    X2[i] = image_cache[image_filename_r]
                    X3[i] = image_cache[image_filename_n]
                else:
                    X1[i] = datagens[0].random_transform(image_cache[image_filename_l])
                    X2[i] = datagens[1].random_transform(image_cache[image_filename_r])
                    X3[i] = datagens[2].random_transform(image_cache[image_filename_n])
            yield [X1, X2, X3]
datagen_args = dict(rotation_range=10,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    zoom_range=0.2)
datagens = [ImageDataGenerator(**datagen_args),
            ImageDataGenerator(**datagen_args),
            ImageDataGenerator(**datagen_args)]
pair_gen = pair_generator(test_triples_data, image_cache, datagens, 32)
[X1, X2, X3] = pair_gen.__next__()

Traceback (последний вызов был последним): [X1, X2, X3] = pair_gen. следующий () X1 [i] = datagens [0] .random_transform (image_cache [image_filename_l]) IndexError: индекс 32 выходит за пределы оси 0 с размером 32

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...