Генератор Keras и fit_generator, как создать генератор, чтобы избежать ошибки 'shape shape' - PullRequest
0 голосов
/ 29 марта 2019

Я создаю генератор для Keras, чтобы иметь возможность загружать изображения из моего набора данных, так как он немного велик для моего оперативной памяти.

Я построил генератор так:

# import the necessary packages
import tensorflow
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import pandas as pd
from tqdm import tqdm

#loading
path_to_txt = "/content/test/leafsnap-dataset/leafsnap-dataset- 
images_improved.txt"
df = pd.read_csv(path_to_txt ,sep='\t')
arr = np.array(df)
#epochs and steps:
NUM_TRAIN_IMAGES = 0
NUM_EPOCHS = 30

def image_generator(arr, bs, mode="train", aug=None):
  while True:
    images = []
    labels = []
    for row in arr:
      if len(images) < bs:
        img = (cv2.resize(cv2.imread("/content/test/leafsnap-dataset/" + 
        row[0]),(224,224)))
        images.append(img)
        labels.append([row[2]])
        NUM_TRAIN_IMAGES += 1
      else:
        break


  if aug is not None:
    (images, labels) = next(aug.flow(np.array(images),labels, 
     batch_size=bs))

  obj = OneHotEncoder()
  values = obj.fit_transform(labels).toarray()

  yield (np.array(images), labels)

Затем я вызываю fit_generator из последовательной модели (cnn работал до тех пор, пока не получил ошибку OOM)

#create the augmentation function:
 aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
    width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
    horizontal_flip=True, fill_mode="nearest")

#create the generator:
gen = image_generator(arr, bs = 32, mode = "train", aug = aug)

history = model.fit_generator(image_generator,
    steps_per_epoch = NUM_TRAIN_IMAGES,
    epochs = NUM_EPOCHS)

И отсюда я получаю эту ошибку:

# Create generator from NumPy or EagerTensor Input.
--> 377   num_samples = int(nest.flatten(data)[0].shape[0])
378   if batch_size is None:
379     raise ValueError('You must specify `batch_size`')
AttributeError: 'function' object has no attribute 'shape'

1 Ответ

1 голос
/ 29 марта 2019

Я вижу две основные ошибки здесь.

Во-первых, ваша функция генератора неэффективна для памяти. Потому что вы загружаете все изображения сначала (пока цикл). Вы должны перебрать файлы изображений и внутри цикла вывести np.array изображения с меткой.

Во-вторых, вы передаете имя функции генератора в fit_generator, когда вам нужно использовать его возвращенный объект - gen.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...