Я создал список из двух других списков, который выглядит следующим образом:
samples = list(map(lambda x, y: [x,y], image_path, labels8))
[['s01_l01/1_1.png', '7C2 4698'],
['s01_l01/2_1.png', '7C2 4698'],
['s01_l01/2_2.png', '7C2 4698'],
['s01_l01/2_3.png', '7C2 4698'],
['s01_l01/2_4.png', '7C2 4698']]
Первая запись - image_path, вторая - метка.
Я также создал эту функцию:
def shuffle_data(data):
data=random.shuffle(data)
return data
Чтобы получить data_generator, я изменил код, который нашел в видео на YouTube (https://www.youtube.com/watch?v=EkzB6PJIcCA&t=530s):
def data_generator(samples, batch_size=32, shuffle_data = True, resize=224):
num_samples = len(samples)
while True:
samples = random.shuffle(samples)
for offset in range(0, num_samples, batch_size):
batch_samples = samples[offset: offset + batch_size]
X_train = []
y_train = []
for batch_sample in batch_samples:
img_name = batch_sample[0]
label = batch_sample[1]
img = cv2.imread(os.path.join(root_dir, img_name))
#img, label = preprocessing(img, label, new_height=224, new_width=224, num_classes=37)
img = preprocessing(img, new_height=224, new_width=224)
label = my_onehot_encoded(label)
X_train.append(img)
y_train.append(label)
X_train = np.array(X_train)
y_train = np.array(y_train)
yield X_train, y_train
Когда я сейчас попробую для выполнения этого кода:
train_datagen = data_generator(samples, batch_size=32)
x, y = next(train_datagen)
print('x_shape: ', x.shape)
print('labels shape: ', y.shape)
print('labels: ', y)
Я получил следующий код ошибки:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-89-6adc7f4509cd> in <module>()
1 train_datagen = data_generator(samples, batch_size=32)
2
----> 3 x, y = next(train_datagen)
4 print('x_shape: ', x.shape)
5 print('labels shape: ', y.shape)
<ipython-input-88-0f34e3e5c990> in data_generator(samples, batch_size, shuffle_data, resize)
5
6 for offset in range(0, num_samples, batch_size):
----> 7 batch_samples = samples[offset: offset + batch_size]
8
9 X_train = []
TypeError: 'NoneType' object is not subscriptable
Я не понимаю, где ошибка ...